amaanwanie committed on
Commit
c969ea5
·
verified ·
1 Parent(s): fc08985

added multi-segmentation support

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -105,26 +105,37 @@ def detection_fn(image, prompt):
105
  def segmentation_fn(image, prompt):
106
  image_np = np.array(image)
107
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
 
 
108
  detections, _ = dino_model.predict_with_caption(
109
  image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
110
  )
 
111
  boxes = detections.xyxy
112
  sam_predictor.set_image(image_np)
113
- masks, scores, _ = sam_predictor.predict(box=boxes, multimask_output=True)
114
- if masks is None or len(masks) == 0:
 
 
 
 
 
 
 
115
  raise ValueError("No masks found")
116
- mask = masks[np.argmax(scores)]
 
 
117
 
118
- # Visualize mask
119
  def overlay_mask(mask, image):
120
- color = np.concatenate([np.random.random(3), np.array([0.8])])
121
- h, w = mask.shape[-2:]
122
- mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
123
- image_pil = Image.fromarray(image).convert("RGBA")
124
- mask_pil = Image.fromarray((mask_img * 255).astype(np.uint8)).convert("RGBA")
125
- return np.array(Image.alpha_composite(image_pil, mask_pil))
126
-
127
- return overlay_mask(mask, image_np)
128
 
129
  def inpainting_fn(image, prompt):
130
  image_np = np.array(image)
 
105
  def segmentation_fn(image, prompt):
106
  image_np = np.array(image)
107
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
108
+
109
+ # Detect with Grounding DINO
110
  detections, _ = dino_model.predict_with_caption(
111
  image=image_cv, caption=prompt, box_threshold=0.35, text_threshold=0.25
112
  )
113
+
114
  boxes = detections.xyxy
115
  sam_predictor.set_image(image_np)
116
+
117
+ all_masks = []
118
+ for box in boxes:
119
+ box = box.reshape(1, 4)
120
+ masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
121
+ if masks is not None:
122
+ all_masks.append(masks[np.argmax(scores)])
123
+
124
+ if not all_masks:
125
  raise ValueError("No masks found")
126
+
127
+ # Combine masks into one binary mask
128
+ merged_mask = np.any(all_masks, axis=0).astype(np.uint8) * 255
129
 
130
+ # Overlay on image
131
  def overlay_mask(mask, image):
132
+ color = np.array([0, 255, 0], dtype=np.uint8) # Green
133
+ mask_rgb = np.stack([mask] * 3, axis=-1)
134
+ overlay = np.where(mask_rgb, color, image)
135
+ return overlay
136
+
137
+ return overlay_mask(merged_mask, image_np)
138
+
 
139
 
140
  def inpainting_fn(image, prompt):
141
  image_np = np.array(image)