Spaces:
Sleeping
Sleeping
Commit
·
354d315
1
Parent(s):
0a9ad49
Add legend to prediction overlay with task color boxes and labels
Browse files- inference_utils/model.py +38 -1
inference_utils/model.py
CHANGED
@@ -14,7 +14,7 @@ from dataclasses import dataclass
|
|
14 |
import os
|
15 |
from typing import Tuple
|
16 |
|
17 |
-
from PIL import Image
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
import matplotlib.pyplot as plt
|
20 |
import numpy as np
|
@@ -113,6 +113,43 @@ def predict(
|
|
113 |
masks = [1 * (pred_mask[i] > 0.5) for i in range(len(pred_tasks_found))]
|
114 |
pred_overlay = overlay_masks(image, masks, colors)
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
return pred_overlay, pred_tasks_not_found
|
117 |
|
118 |
|
|
|
14 |
import os
|
15 |
from typing import Tuple
|
16 |
|
17 |
+
from PIL import Image, ImageDraw
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
import matplotlib.pyplot as plt
|
20 |
import numpy as np
|
|
|
113 |
masks = [1 * (pred_mask[i] > 0.5) for i in range(len(pred_tasks_found))]
|
114 |
pred_overlay = overlay_masks(image, masks, colors)
|
115 |
|
116 |
+
# Add legend
|
117 |
+
if len(pred_tasks_found) > 0:
|
118 |
+
# Convert to numpy for manipulation
|
119 |
+
pred_overlay = np.array(pred_overlay)
|
120 |
+
|
121 |
+
# Calculate legend dimensions
|
122 |
+
legend_height = 30 * len(pred_tasks_found) # 30 pixels per entry
|
123 |
+
legend_padding = 10 # padding around legend
|
124 |
+
total_height = pred_overlay.shape[0] + legend_height + 2 * legend_padding
|
125 |
+
|
126 |
+
# Create new image with space for legend
|
127 |
+
new_image = np.zeros((total_height, pred_overlay.shape[1], 3), dtype=np.uint8)
|
128 |
+
new_image[: pred_overlay.shape[0], :] = pred_overlay
|
129 |
+
new_image[pred_overlay.shape[0] :] = 255 # White background for legend
|
130 |
+
|
131 |
+
# Convert to PIL once for all legend entries
|
132 |
+
img_pil = Image.fromarray(new_image)
|
133 |
+
draw = ImageDraw.Draw(img_pil)
|
134 |
+
|
135 |
+
# Draw legend entries
|
136 |
+
start_y = pred_overlay.shape[0] + legend_padding
|
137 |
+
for i, task in enumerate(pred_tasks_found):
|
138 |
+
# Draw color box
|
139 |
+
box_x = 10
|
140 |
+
box_y = start_y + i * 30
|
141 |
+
box_size = 20
|
142 |
+
box_coords = (box_x, box_y, box_x + box_size, box_y + box_size)
|
143 |
+
draw.rectangle(box_coords, fill=colors[i])
|
144 |
+
|
145 |
+
# Draw text (vertically centered with color box)
|
146 |
+
text_y = box_y + (box_size - 12) // 2 # Assuming ~12px text height
|
147 |
+
draw.text((box_x + box_size + 10, text_y), task.target, fill=(0, 0, 0))
|
148 |
+
|
149 |
+
pred_overlay = img_pil
|
150 |
+
else:
|
151 |
+
pred_overlay = Image.fromarray(np.array(pred_overlay))
|
152 |
+
|
153 |
return pred_overlay, pred_tasks_not_found
|
154 |
|
155 |
|