Spaces:
Running
on
Zero
Running
on
Zero
- sam2_mask.py +29 -6
sam2_mask.py
CHANGED
@@ -11,8 +11,11 @@ from sam2.build_sam import build_sam2
|
|
11 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
12 |
from gradio_image_prompter import ImagePrompter
|
13 |
|
|
|
|
|
|
|
14 |
def preprocess_image(image):
|
15 |
-
return image,
|
16 |
|
17 |
def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
|
18 |
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
@@ -75,6 +78,28 @@ def show_points(coords, labels, ax, marker_size=375):
|
|
75 |
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
76 |
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
def show_box(box, ax):
|
79 |
x0, y0 = box[0], box[1]
|
80 |
w, h = box[2] - box[0], box[3] - box[1]
|
@@ -153,11 +178,10 @@ def create_sam2_tab():
|
|
153 |
|
154 |
with gr.Column():
|
155 |
gr.Markdown("# SAM2 Image Predictor")
|
156 |
-
|
157 |
-
|
158 |
-
input_image = ImagePrompter(show_label=False)
|
159 |
points_map = gr.Image(label="Points Map", type="pil", interactive=True)
|
160 |
-
|
161 |
|
162 |
with gr.Row():
|
163 |
point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
|
@@ -192,4 +216,3 @@ def create_sam2_tab():
|
|
192 |
)
|
193 |
|
194 |
return input_image, points_map, output_image
|
195 |
-
|
|
|
11 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
12 |
from gradio_image_prompter import ImagePrompter
|
13 |
|
14 |
+
# def preprocess_image(image):
|
15 |
+
# return image, gr.State([]), gr.State([]), image
|
16 |
+
|
17 |
def preprocess_image(image):
|
18 |
+
return image, [], [], image
|
19 |
|
20 |
def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
|
21 |
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
|
|
78 |
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
79 |
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
80 |
|
81 |
+
|
82 |
+
def show_mask(mask, ax, random_color=False, borders=True):
|
83 |
+
if random_color:
|
84 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
85 |
+
else:
|
86 |
+
color = np.array([30/255, 144/255, 255/255, 0.6])
|
87 |
+
h, w = mask.shape[-2:]
|
88 |
+
mask = mask.astype(np.uint8)
|
89 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
90 |
+
if borders:
|
91 |
+
import cv2
|
92 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
93 |
+
# Try to smooth contours
|
94 |
+
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
|
95 |
+
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
|
96 |
+
ax.imshow(mask_image)
|
97 |
+
def show_points(coords, labels, ax, marker_size=375):
|
98 |
+
pos_points = coords[labels == 1]
|
99 |
+
neg_points = coords[labels == 0]
|
100 |
+
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
101 |
+
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
102 |
+
|
103 |
def show_box(box, ax):
|
104 |
x0, y0 = box[0], box[1]
|
105 |
w, h = box[2] - box[0], box[3] - box[1]
|
|
|
178 |
|
179 |
with gr.Column():
|
180 |
gr.Markdown("# SAM2 Image Predictor")
|
181 |
+
gr.Markdown("1. Upload your image\n2. Click points to mask\n3. Submit")
|
182 |
+
|
|
|
183 |
points_map = gr.Image(label="Points Map", type="pil", interactive=True)
|
184 |
+
input_image = gr.Image(type="pil", visible=False) # Original image
|
185 |
|
186 |
with gr.Row():
|
187 |
point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
|
|
|
216 |
)
|
217 |
|
218 |
return input_image, points_map, output_image
|
|