LPX55 commited on
Commit
ea0d88d
·
1 Parent(s): 2314f9b
Files changed (1) hide show
  1. sam2_mask.py +29 -6
sam2_mask.py CHANGED
@@ -11,8 +11,11 @@ from sam2.build_sam import build_sam2
11
  from sam2.sam2_image_predictor import SAM2ImagePredictor
12
  from gradio_image_prompter import ImagePrompter
13
 
 
 
 
14
  def preprocess_image(image):
15
- return image, gr.State([]), gr.State([]), image
16
 
17
  def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
18
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
@@ -75,6 +78,28 @@ def show_points(coords, labels, ax, marker_size=375):
75
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
76
  ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def show_box(box, ax):
79
  x0, y0 = box[0], box[1]
80
  w, h = box[2] - box[0], box[3] - box[1]
@@ -153,11 +178,10 @@ def create_sam2_tab():
153
 
154
  with gr.Column():
155
  gr.Markdown("# SAM2 Image Predictor")
156
-
157
- image_input = gr.State()
158
- input_image = ImagePrompter(show_label=False)
159
  points_map = gr.Image(label="Points Map", type="pil", interactive=True)
160
- # image_input = gr.Image(type="pil", visible=False) # Original image
161
 
162
  with gr.Row():
163
  point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
@@ -192,4 +216,3 @@ def create_sam2_tab():
192
  )
193
 
194
  return input_image, points_map, output_image
195
-
 
11
  from sam2.sam2_image_predictor import SAM2ImagePredictor
12
  from gradio_image_prompter import ImagePrompter
13
 
14
+ # def preprocess_image(image):
15
+ # return image, gr.State([]), gr.State([]), image
16
+
17
  def preprocess_image(image):
18
+ return image, [], [], image
19
 
20
  def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
21
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
 
78
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
79
  ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
80
 
81
+
82
+ def show_mask(mask, ax, random_color=False, borders=True):
83
+ if random_color:
84
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
85
+ else:
86
+ color = np.array([30/255, 144/255, 255/255, 0.6])
87
+ h, w = mask.shape[-2:]
88
+ mask = mask.astype(np.uint8)
89
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
90
+ if borders:
91
+ import cv2
92
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
93
+ # Try to smooth contours
94
+ contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
95
+ mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
96
+ ax.imshow(mask_image)
97
+ def show_points(coords, labels, ax, marker_size=375):
98
+ pos_points = coords[labels == 1]
99
+ neg_points = coords[labels == 0]
100
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
101
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
102
+
103
  def show_box(box, ax):
104
  x0, y0 = box[0], box[1]
105
  w, h = box[2] - box[0], box[3] - box[1]
 
178
 
179
  with gr.Column():
180
  gr.Markdown("# SAM2 Image Predictor")
181
+ gr.Markdown("1. Upload your image\n2. Click points to mask\n3. Submit")
182
+
 
183
  points_map = gr.Image(label="Points Map", type="pil", interactive=True)
184
+ input_image = gr.Image(type="pil", visible=False) # Original image
185
 
186
  with gr.Row():
187
  point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
 
216
  )
217
 
218
  return input_image, points_map, output_image