kernel-luso-comfort commited on
Commit
354d315
·
1 Parent(s): 0a9ad49

Add legend to prediction overlay with task color boxes and labels

Browse files
Files changed (1) hide show
  1. inference_utils/model.py +38 -1
inference_utils/model.py CHANGED
@@ -14,7 +14,7 @@ from dataclasses import dataclass
14
  import os
15
  from typing import Tuple
16
 
17
- from PIL import Image
18
  from huggingface_hub import hf_hub_download
19
  import matplotlib.pyplot as plt
20
  import numpy as np
@@ -113,6 +113,43 @@ def predict(
113
  masks = [1 * (pred_mask[i] > 0.5) for i in range(len(pred_tasks_found))]
114
  pred_overlay = overlay_masks(image, masks, colors)
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  return pred_overlay, pred_tasks_not_found
117
 
118
 
 
14
  import os
15
  from typing import Tuple
16
 
17
+ from PIL import Image, ImageDraw
18
  from huggingface_hub import hf_hub_download
19
  import matplotlib.pyplot as plt
20
  import numpy as np
 
113
  masks = [1 * (pred_mask[i] > 0.5) for i in range(len(pred_tasks_found))]
114
  pred_overlay = overlay_masks(image, masks, colors)
115
 
116
+ # Add legend
117
+ if len(pred_tasks_found) > 0:
118
+ # Convert to numpy for manipulation
119
+ pred_overlay = np.array(pred_overlay)
120
+
121
+ # Calculate legend dimensions
122
+ legend_height = 30 * len(pred_tasks_found) # 30 pixels per entry
123
+ legend_padding = 10 # padding around legend
124
+ total_height = pred_overlay.shape[0] + legend_height + 2 * legend_padding
125
+
126
+ # Create new image with space for legend
127
+ new_image = np.zeros((total_height, pred_overlay.shape[1], 3), dtype=np.uint8)
128
+ new_image[: pred_overlay.shape[0], :] = pred_overlay
129
+ new_image[pred_overlay.shape[0] :] = 255 # White background for legend
130
+
131
+ # Convert to PIL once for all legend entries
132
+ img_pil = Image.fromarray(new_image)
133
+ draw = ImageDraw.Draw(img_pil)
134
+
135
+ # Draw legend entries
136
+ start_y = pred_overlay.shape[0] + legend_padding
137
+ for i, task in enumerate(pred_tasks_found):
138
+ # Draw color box
139
+ box_x = 10
140
+ box_y = start_y + i * 30
141
+ box_size = 20
142
+ box_coords = (box_x, box_y, box_x + box_size, box_y + box_size)
143
+ draw.rectangle(box_coords, fill=colors[i])
144
+
145
+ # Draw text (vertically centered with color box)
146
+ text_y = box_y + (box_size - 12) // 2 # Assuming ~12px text height
147
+ draw.text((box_x + box_size + 10, text_y), task.target, fill=(0, 0, 0))
148
+
149
+ pred_overlay = img_pil
150
+ else:
151
+ pred_overlay = Image.fromarray(np.array(pred_overlay))
152
+
153
  return pred_overlay, pred_tasks_not_found
154
 
155