# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import random

# Build the local package in-place so its custom ops and modules are
# importable (this script is meant to run as a Hugging Face Space).
os.system('python setup.py develop')

import matplotlib

matplotlib.use('agg')  # headless backend; set before pyplot is imported

import gradio as gr
import numpy as np
import torch
from matplotlib import pyplot as plt
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from PIL import Image, ImageDraw
from torchvision import transforms

from demo import Resize_Pad
# The star import registers the custom models and provides
# TopDownGenerateTargetFewShot, used below.
from models import *
def plot_results(support_img, query_img, support_kp, support_w, query_kp,
                 query_w, skeleton,
                 initial_proposals, prediction, radius=6):
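    """Draw the predicted keypoints and limbs on top of the query image.

    support_img only supplies the pixel scale for the predictions; support_kp,
    query_kp, and initial_proposals are unused and kept so the signature
    matches the call in process(), which passes None for them.
    """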
    h, w, c = support_img.shape
    prediction = prediction[-1].cpu().numpy() * h
    # Normalize the query image to [0, 1] for display.
    query_img = (query_img - np.min(query_img)) / (
            np.max(query_img) - np.min(query_img))
    for img, kp_weight, keypoint in zip([query_img], [query_w], [prediction]):
        fig, axes = plt.subplots()
        plt.imshow(img)
        # Draw each visible keypoint as a filled circle with its index.
        for k in range(keypoint.shape[0]):
            if kp_weight[k] > 0:
                kp = keypoint[k, :2]
                c = (1, 0, 0, 0.75) if kp_weight[k] == 1 else (0, 0, 1, 0.6)
                patch = plt.Circle(kp, radius, color=c)
                axes.add_patch(patch)
                axes.text(kp[0], kp[1], k)
                plt.draw()
        # Draw each limb whose endpoints are both visible.
        for limb_idx, limb in enumerate(skeleton):
            kp = keypoint[:, :2]
            if limb_idx > len(COLORS) - 1:
                c = [x / 255 for x in random.sample(range(0, 255), 3)]
            else:
                c = [x / 255 for x in COLORS[limb_idx]]
            if kp_weight[limb[0]] > 0 and kp_weight[limb[1]] > 0:
                patch = plt.Line2D([kp[limb[0], 0], kp[limb[1], 0]],
                                   [kp[limb[0], 1], kp[limb[1], 1]],
                                   linewidth=6, color=c, alpha=0.6)
                axes.add_artist(patch)
        plt.axis('off')  # hide the axes
        plt.subplots_adjust(0, 0, 1, 1, 0, 0)
    return plt

COLORS = [
    [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
    [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
    [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
    [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]
]

# Mutable state shared between the Gradio callbacks below.
kp_src = []  # clicked support keypoints as (row, col) pixels
skeleton = []  # limbs as (keypoint_index, keypoint_index) pairs
count = 0  # 0 -> next click starts a limb, 1 -> next click closes it
color_idx = 0  # next limb color
prev_pt = None
prev_pt_idx = None
prev_clicked = None
original_support_image = None
checkpoint_path = ''

def process(query_img,
            cfg_path='configs/demo_b.py'):
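    """Run one-shot pose inference on the query image.

    Builds support keypoint heatmaps from the globally stored annotations
    (kp_src, skeleton, original_support_image), loads the model from
    checkpoint_path, and returns a matplotlib plot of the prediction.
    """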
    global skeleton
    cfg = Config.fromfile(cfg_path)
    # Rescale the clicked keypoints from the 128x128 annotation canvas to the
    # model input size, and flip (row, col) to (x, y).
    kp_src_np = np.array(kp_src).copy().astype(np.float32)
    kp_src_np[:, 0] = kp_src_np[:, 0] / 128. * cfg.model.encoder_config.img_size
    kp_src_np[:, 1] = kp_src_np[:, 1] / 128. * cfg.model.encoder_config.img_size
    kp_src_np = np.flip(kp_src_np, 1).copy()
    kp_src_tensor = torch.tensor(kp_src_np).float()
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        Resize_Pad(cfg.model.encoder_config.img_size,
                   cfg.model.encoder_config.img_size)])
    if len(skeleton) == 0:
        skeleton = [(0, 0)]  # dummy limb so downstream code sees a skeleton
    support_img = preprocess(original_support_image).flip(0)[None]
    np_query = np.array(query_img)[:, :, ::-1].copy()  # RGB -> BGR
    q_img = preprocess(np_query).flip(0)[None]  # flip(0) restores RGB order
    # Create heatmap targets from the support keypoints.
    genHeatMap = TopDownGenerateTargetFewShot()
    data_cfg = cfg.data_cfg
    data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size,
                                       cfg.model.encoder_config.img_size])
    data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False
    # Pad keypoints to (x, y, 0) triplets with per-joint visibility weights.
    kp_src_3d = torch.cat(
        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    kp_src_3d_weight = torch.cat(
        (torch.ones_like(kp_src_tensor),
         torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg,
                                                                 kp_src_3d,
                                                                 kp_src_3d_weight,
                                                                 sigma=1)
    target_s = torch.tensor(target_s).float()[None]
    target_weight_s = torch.ones_like(
        torch.tensor(target_weight_s).float()[None])
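    # Assemble a single one-shot "episode": the annotated support image with
    # its heatmap targets plus the query image, along with the metadata
    # fields the MMPose-style test-time forward pass expects.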
    data = {
        'img_s': [support_img],
        'img_q': q_img,
        'target_s': [target_s],
        'target_weight_s': [target_weight_s],
        'target_q': None,
        'target_weight_q': None,
        'return_loss': False,
        'img_metas': [{'sample_skeleton': [skeleton],
                       'query_skeleton': skeleton,
                       'sample_joints_3d': [kp_src_3d],
                       'query_joints_3d': kp_src_3d,
                       'sample_center': [kp_src_tensor.mean(dim=0)],
                       'query_center': kp_src_tensor.mean(dim=0),
                       'sample_scale': [
                           kp_src_tensor.max(dim=0)[0] -
                           kp_src_tensor.min(dim=0)[0]],
                       'query_scale': kp_src_tensor.max(dim=0)[0] -
                                      kp_src_tensor.min(dim=0)[0],
                       'sample_rotation': [0],
                       'query_rotation': 0,
                       'sample_bbox_score': [1],
                       'query_bbox_score': 1,
                       'query_image_file': '',
                       'sample_image_file': [''],
                       }]
    }
    # Load the model and checkpoint.
    model = build_posenet(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    model.eval()
    with torch.no_grad():
        outputs = model(**data)
    # Visualize the results.
    vis_s_weight = target_weight_s[0]
    vis_q_weight = target_weight_s[0]  # query weights mirror the support ones
    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    support_kp = kp_src_3d
    out = plot_results(vis_s_image,
                       vis_q_image,
                       support_kp,
                       vis_s_weight,
                       None,
                       vis_q_weight,
                       skeleton,
                       None,
                       torch.tensor(outputs['points']).squeeze(0),
                       )
    return out
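
# Gradio UI: mark keypoints on the support image (left), connect them into
# limbs on the posed support image (right), then upload a query image and
# click Evaluate.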
with gr.Blocks() as demo:
    gr.Markdown('''
    # Pose Anything Demo
    We present a novel approach to category-agnostic pose estimation that leverages the inherent geometrical relations between keypoints through a newly designed Graph Transformer Decoder. By capturing and incorporating this crucial structural information, our method enhances the accuracy of keypoint localization, marking a significant departure from conventional CAPE techniques that treat keypoints as isolated entities.
    ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](https://github.com/orhir/PoseAnything)
    

    ## Instructions
    1. Upload an image of the object you want to pose to the **left** (support) image.
    2. Click on the **left** image to mark keypoints.
    3. Click on the keypoints in the **right** image to mark limbs.
    4. Upload an image of the object you want to pose to the query image (**bottom**).
    5. Click **Evaluate** to pose the query image.
    ''')
    with gr.Row():
        support_img = gr.Image(label="Support Image",
                               type="pil",
                               info='Click to mark keypoints').style(
            height=256, width=256)
        posed_support = gr.Image(label="Posed Support Image",
                                 type="pil",
                                 interactive=False).style(height=256,
                                                          width=256)
    with gr.Row():
        query_img = gr.Image(label="Query Image",
                             type="pil").style(height=256, width=256)
    with gr.Row():
        eval_btn = gr.Button(value="Evaluate")
    with gr.Row():
        output_img = gr.Plot(label="Output Image", height=256, width=256)

    def get_select_coords(kp_support,
                          limb_support,
                          evt: gr.SelectData,
                          r=0.015):
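        """Record a click on the support image as a new keypoint.

        The clicked pixel is appended to kp_src and drawn as a red dot on
        both the keypoint canvas and the limb canvas; r is the marker radius
        as a fraction of the image width.
        """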
        pixels_in_queue = set()
        pixels_in_queue.add((evt.index[1], evt.index[0]))
        # The queue holds a single pixel, so this loop runs exactly once.
        while len(pixels_in_queue) > 0:
            pixel = pixels_in_queue.pop()
            if pixel[0] is not None and pixel[
                    1] is not None and pixel not in kp_src:
                kp_src.append(pixel)
            else:
                print("Invalid pixel")
            if limb_support is None:
                canvas_limb = kp_support
            else:
                canvas_limb = limb_support
            canvas_kp = kp_support
            w, h = canvas_kp.size
            draw_pose = ImageDraw.Draw(canvas_kp)
            draw_limb = ImageDraw.Draw(canvas_limb)
            r = int(r * w)
            leftUpPoint = (pixel[1] - r, pixel[0] - r)
            rightDownPoint = (pixel[1] + r, pixel[0] + r)
            twoPointList = [leftUpPoint, rightDownPoint]
            draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
            draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
        return canvas_kp, canvas_limb

    def get_limbs(kp_support,
                  evt: gr.SelectData,
                  r=0.02, width=0.02):
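        """Connect two successive clicks into a limb.

        The first click selects one endpoint (snapped to the nearest keypoint
        in kp_src); the second click selects the other endpoint and appends
        the pair to skeleton. Clicking the same keypoint twice cancels the
        limb in progress.
        """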
        global count, color_idx, prev_pt, skeleton, prev_pt_idx, prev_clicked
        curr_pixel = (evt.index[1], evt.index[0])
        pixels_in_queue = set()
        pixels_in_queue.add((evt.index[1], evt.index[0]))
        canvas_kp = kp_support
        w, h = canvas_kp.size
        r = int(r * w)
        width = int(width * w)
        while (len(pixels_in_queue) > 0 and
               curr_pixel != prev_clicked and
               len(kp_src) > 0):
            pixel = pixels_in_queue.pop()
            prev_clicked = pixel
            # Snap the click to the nearest marked keypoint.
            closest_point = min(kp_src,
                                key=lambda p: (p[0] - pixel[0]) ** 2 +
                                              (p[1] - pixel[1]) ** 2)
            closest_point_index = kp_src.index(closest_point)
            draw_limb = ImageDraw.Draw(canvas_kp)
            if color_idx < len(COLORS):
                c = COLORS[color_idx]
            else:
                c = random.choices(range(256), k=3)
            leftUpPoint = (closest_point[1] - r, closest_point[0] - r)
            rightDownPoint = (closest_point[1] + r, closest_point[0] + r)
            twoPointList = [leftUpPoint, rightDownPoint]
            draw_limb.ellipse(twoPointList, fill=tuple(c))
            if count == 0:
                # First endpoint of a new limb.
                prev_pt = closest_point[1], closest_point[0]
                prev_pt_idx = closest_point_index
                count = count + 1
            else:
                if prev_pt_idx != closest_point_index:
                    # Second endpoint: draw the line and record the limb.
                    draw_limb.line(
                        [prev_pt, (closest_point[1], closest_point[0])],
                        fill=tuple(c),
                        width=width)
                    skeleton.append((prev_pt_idx, closest_point_index))
                    color_idx = color_idx + 1
                else:
                    # Same keypoint twice: mark it red and drop the limb.
                    draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
                count = 0
        return canvas_kp

    def set_query(support_img):
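        """Reset all annotation state for a newly uploaded support image and
        shrink the image to the 128x128 annotation canvas."""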
        global original_support_image
        skeleton.clear()
        kp_src.clear()
        original_support_image = np.array(support_img)[:, :, ::-1].copy()
        support_img = support_img.resize((128, 128), Image.Resampling.LANCZOS)
        return support_img, support_img

    # Wire up the callbacks: clicks mark keypoints and limbs, uploading a
    # support image resets the state, and Evaluate runs inference.
    support_img.select(get_select_coords,
                       [support_img, posed_support],
                       [support_img, posed_support],
                       )
    support_img.upload(set_query,
                       inputs=support_img,
                       outputs=[support_img, posed_support])
    posed_support.select(get_limbs,
                         posed_support,
                         posed_support)
    eval_btn.click(fn=process,
                   inputs=[query_img],
                   outputs=output_img)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Pose Anything Demo')
    parser.add_argument('--checkpoint',
                        help='checkpoint path',
                        default='1shot-swin_graph_split1.pth')
    args = parser.parse_args()
    checkpoint_path = args.checkpoint
    print("Loading checkpoint from {}".format(checkpoint_path))
    print("Checkpoint exists: {}".format(os.path.exists(checkpoint_path)))
    demo.launch()
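
# Example invocation, assuming the checkpoint file has been downloaded next
# to this script:
#   python app.py --checkpoint 1shot-swin_graph_split1.pth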