Update README.md
README.md

---
library_name: transformers
license: apache-2.0
---

# SynthPose (Transformers 🤗 VitPose Huge variant)

The SynthPose model was proposed in [OpenCapBench: A Benchmark to Bridge Pose Estimation and Biomechanics](https://arxiv.org/abs/2406.09788) by Yoni Gozlan, Antoine Falisse, Scott Uhlrich, Anthony Gatti, Michael Black, Akshay Chaudhari.

# Intended use cases

This model uses a VitPose Huge backbone.
SynthPose is a new approach that enables finetuning of pre-trained 2D human pose models to predict an arbitrarily denser set of keypoints for accurate kinematic analysis through the use of synthetic data.
More details are available in [OpenCapBench: A Benchmark to Bridge Pose Estimation and Biomechanics](https://arxiv.org/abs/2406.09788).
This particular variant was finetuned on a set of keypoints usually found in motion capture setups, and includes the COCO keypoints as well.

The model predicts the following 52 markers:

```py
{
    0: "Nose",
    1: "L_Eye",
    2: "R_Eye",
    3: "L_Ear",
    4: "R_Ear",
    5: "L_Shoulder",
    6: "R_Shoulder",
    7: "L_Elbow",
    8: "R_Elbow",
    9: "L_Wrist",
    10: "R_Wrist",
    11: "L_Hip",
    12: "R_Hip",
    13: "L_Knee",
    14: "R_Knee",
    15: "L_Ankle",
    16: "R_Ankle",
    17: "sternum",
    18: "rshoulder",
    19: "lshoulder",
    20: "r_lelbow",
    21: "l_lelbow",
    22: "r_melbow",
    23: "l_melbow",
    24: "r_lwrist",
    25: "l_lwrist",
    26: "r_mwrist",
    27: "l_mwrist",
    28: "r_ASIS",
    29: "l_ASIS",
    30: "r_PSIS",
    31: "l_PSIS",
    32: "r_knee",
    33: "l_knee",
    34: "r_mknee",
    35: "l_mknee",
    36: "r_ankle",
    37: "l_ankle",
    38: "r_mankle",
    39: "l_mankle",
    40: "r_5meta",
    41: "l_5meta",
    42: "r_toe",
    43: "l_toe",
    44: "r_big_toe",
    45: "l_big_toe",
    46: "l_calc",
    47: "r_calc",
    48: "C7",
    49: "L2",
    50: "T11",
    51: "T6",
}
```

Where the first 17 keypoints are the COCO keypoints, and the next 35 are anatomical markers.

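Since the first 17 ids follow the standard COCO convention, the two subsets are easy to address separately in code. A minimal sketch (the id ranges below simply restate the split described above; they are not part of any library API):

```py
# Split the marker ids listed above into the COCO subset and the anatomical subset.
# Ids 0-16 are the standard COCO keypoints; ids 17-51 are the added mocap-style markers.
COCO_KEYPOINT_IDS = list(range(17))
ANATOMICAL_MARKER_IDS = list(range(17, 52))

assert len(COCO_KEYPOINT_IDS) + len(ANATOMICAL_MARKER_IDS) == 52
```
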
# Usage

## Image inference

Here's how to load the model and run inference on an image:

```py
import torch
import requests
import numpy as np

from PIL import Image

from transformers import (
    AutoProcessor,
    RTDetrForObjectDetection,
    VitPoseForPoseEstimation,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

url = "http://farm4.staticflickr.com/3300/3416216247_f9c6dfc939_z.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# ------------------------------------------------------------------------
# Stage 1. Detect humans in the image
# ------------------------------------------------------------------------

# You can use any person detector of your choice here
person_image_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)

inputs = person_image_processor(images=image, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = person_model(**inputs)

results = person_image_processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
)
result = results[0]  # take first image results

# The "person" label corresponds to index 0 in the COCO dataset
person_boxes = result["boxes"][result["labels"] == 0]
person_boxes = person_boxes.cpu().numpy()

# Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

# ------------------------------------------------------------------------
# Stage 2. Detect keypoints for each person found
# ------------------------------------------------------------------------

image_processor = AutoProcessor.from_pretrained("yonigozlan/synthpose-vitpose-huge-hf")
model = VitPoseForPoseEstimation.from_pretrained("yonigozlan/synthpose-vitpose-huge-hf", device_map=device)

inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
image_pose_result = pose_results[0]  # results for first image
```

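The post-processed results can also be consumed numerically, for example to export marker coordinates for kinematic analysis. Below is a minimal sketch; `MARKER_NAMES` is assumed to be the id-to-name dictionary listed earlier in this card, copied into a variable by the user, and is not provided by the library:

```py
# Minimal sketch: print each detected person's markers as name -> (x, y) plus confidence.
# MARKER_NAMES is assumed to hold the {id: name} mapping shown earlier in this card.
for person_idx, pose_result in enumerate(image_pose_result):
    keypoints = pose_result["keypoints"].cpu().numpy()  # (52, 2) array of x, y coordinates
    scores = pose_result["scores"].cpu().numpy()  # (52,) array of confidences
    print(f"Person {person_idx}:")
    for marker_id, ((x, y), score) in enumerate(zip(keypoints, scores)):
        print(f"  {MARKER_NAMES[marker_id]}: x={x:.1f}, y={y:.1f}, score={score:.2f}")
```
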
### Visualization for supervision users

```py
import supervision as sv

xy = torch.stack([pose_result['keypoints'] for pose_result in image_pose_result]).cpu().numpy()
scores = torch.stack([pose_result['scores'] for pose_result in image_pose_result]).cpu().numpy()

key_points = sv.KeyPoints(
    xy=xy, confidence=scores
)

vertex_annotator = sv.VertexAnnotator(
    color=sv.Color.PINK,
    radius=2
)

annotated_frame = vertex_annotator.annotate(
    scene=image.copy(),
    key_points=key_points
)
annotated_frame
```

<p>
    <img src="vitpose_sv.png" width=375>
</p>

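supervision can also draw the skeleton connections on top of the vertices. The sketch below is an optional extension of the snippet above: it assumes the installed supervision version exposes `sv.EdgeAnnotator` with an `edges` argument, and reuses the edge list stored in this model's config:

```py
# Sketch (assumes sv.EdgeAnnotator and its `edges` argument are available in your
# supervision version). model.config.edges holds the skeleton used by this checkpoint.
edge_annotator = sv.EdgeAnnotator(
    color=sv.Color.YELLOW,
    thickness=1,
    edges=model.config.edges,
)

annotated_frame = edge_annotator.annotate(
    scene=annotated_frame.copy(),
    key_points=key_points,
)
annotated_frame
```
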
### Advanced manual visualization

```py
import math
import cv2

def draw_points(image, keypoints, scores, pose_keypoint_color, keypoint_score_threshold, radius, show_keypoint_weight):
    if pose_keypoint_color is not None:
        assert len(pose_keypoint_color) == len(keypoints)
    for kid, (kpt, kpt_score) in enumerate(zip(keypoints, scores)):
        x_coord, y_coord = int(kpt[0]), int(kpt[1])
        if kpt_score > keypoint_score_threshold:
            color = tuple(int(c) for c in pose_keypoint_color[kid])
            if show_keypoint_weight:
                cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)
                transparency = max(0, min(1, kpt_score))
                cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
            else:
                cv2.circle(image, (int(x_coord), int(y_coord)), radius, color, -1)

def draw_links(image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width=2):
    height, width, _ = image.shape
    if keypoint_edges is not None and link_colors is not None:
        assert len(link_colors) == len(keypoint_edges)
        for sk_id, sk in enumerate(keypoint_edges):
            x1, y1, score1 = (int(keypoints[sk[0], 0]), int(keypoints[sk[0], 1]), scores[sk[0]])
            x2, y2, score2 = (int(keypoints[sk[1], 0]), int(keypoints[sk[1], 1]), scores[sk[1]])
            if (
                x1 > 0
                and x1 < width
                and y1 > 0
                and y1 < height
                and x2 > 0
                and x2 < width
                and y2 > 0
                and y2 < height
                and score1 > keypoint_score_threshold
                and score2 > keypoint_score_threshold
            ):
                color = tuple(int(c) for c in link_colors[sk_id])
                if show_keypoint_weight:
                    X = (x1, x2)
                    Y = (y1, y2)
                    mean_x = np.mean(X)
                    mean_y = np.mean(Y)
                    length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
                    angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
                    polygon = cv2.ellipse2Poly(
                        (int(mean_x), int(mean_y)), (int(length / 2), int(stick_width)), int(angle), 0, 360, 1
                    )
                    cv2.fillConvexPoly(image, polygon, color)
                    # use the keypoint scores (keypoints only hold x, y) for the transparency weight
                    transparency = max(0, min(1, 0.5 * (scores[sk[0]] + scores[sk[1]])))
                    cv2.addWeighted(image, transparency, image, 1 - transparency, 0, dst=image)
                else:
                    cv2.line(image, (x1, y1), (x2, y2), color, thickness=thickness)


# Note: keypoint_edges and color palette are dataset-specific
keypoint_edges = model.config.edges

palette = np.array(
    [
        [255, 128, 0],
        [255, 153, 51],
        [255, 178, 102],
        [230, 230, 0],
        [255, 153, 255],
        [153, 204, 255],
        [255, 102, 255],
        [255, 51, 255],
        [102, 178, 255],
        [51, 153, 255],
        [255, 153, 153],
        [255, 102, 102],
        [255, 51, 51],
        [153, 255, 153],
        [102, 255, 102],
        [51, 255, 51],
        [0, 255, 0],
        [0, 0, 255],
        [255, 0, 0],
        [255, 255, 255],
    ]
)

link_colors = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
keypoint_colors = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0] + [4] * (52 - 17)]

numpy_image = np.array(image)

for pose_result in image_pose_result:
    scores = np.array(pose_result["scores"])
    keypoints = np.array(pose_result["keypoints"])

    # draw each point on image
    draw_points(numpy_image, keypoints, scores, keypoint_colors, keypoint_score_threshold=0.3, radius=2, show_keypoint_weight=False)

    # draw links
    draw_links(numpy_image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold=0.3, thickness=1, show_keypoint_weight=False)

pose_image = Image.fromarray(numpy_image)
pose_image
```

<p>
    <img src="vitpose_manual.png" width=375>
</p>
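
If you want to keep the rendered overlay rather than only display it, the PIL image can be written to disk (the filename below is arbitrary):

```py
# Save the annotated image produced by the manual visualization above.
pose_image.save("synthpose_result.png")
```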