LPX55 committed on
Commit 2314f9b · 1 Parent(s): fd14388
Files changed (2)
  1. app.py +5 -2
  2. sam2_mask.py +195 -0
app.py CHANGED
@@ -11,6 +11,7 @@ from gradio_image_prompter import ImagePrompter
 from PIL import Image, ImageDraw
 import numpy as np
 from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2_mask import create_sam2_tab
 import subprocess
 
 subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
@@ -557,6 +558,8 @@ with gr.Blocks(css=css, fill_height=True) as demo:
     use_as_input_button_outpaint = gr.Button("Use as Input Image", visible=False)
     history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
     preview_image = gr.Image(label="Preview")
+    with gr.TabItem("SAM2 Masking"):
+        input_image, points_map, output_result_mask = create_sam2_tab()
     with gr.TabItem("SAM2 Mask"):
         gr.Markdown("# Object Segmentation with SAM2")
         gr.Markdown(
@@ -571,9 +574,9 @@ with gr.Blocks(css=css, fill_height=True) as demo:
     upload_image_input = ImagePrompter(show_label=False)
     with gr.Column():
         image_output = gr.Image(label="Segmented Image", type="pil", height=400)
-
+    with gr.Row():
         # Button to trigger the prediction
-    predict_button = gr.Button("Predict Mask")
+        predict_button = gr.Button("Predict Mask")
 
     # Define the action triggered by the predict button
     predict_button.click(
sam2_mask.py CHANGED
@@ -0,0 +1,195 @@
+ import spaces
+ import gradio as gr
+ import os
+ os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
+ import torch
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ from PIL import Image, ImageFilter
+ from sam2.build_sam import build_sam2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+ from gradio_image_prompter import ImagePrompter
+
+ def preprocess_image(image):
+     # Reset the point/label states; event handlers should return plain
+     # values for gr.State outputs, not new gr.State components
+     return image, [], [], image
+
+ def get_point(point_type, tracking_points, trackings_input_label, first_frame, evt: gr.SelectData):
+     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+
+     # Extract x, y coordinates from evt.index
+     x, y = evt.index
+
+     # Add the point as [x, y]
+     tracking_points.append([x, y])
+     print(f"TRACKING POINTS: {tracking_points}")
+
+     if point_type == "include":
+         trackings_input_label.append(1)
+     elif point_type == "exclude":
+         trackings_input_label.append(0)
+     print(f"TRACKING INPUT LABELS: {trackings_input_label}")
+
+     # The state may hold a file path or a PIL image (the upload handler
+     # stores the image itself), so normalize to an RGBA PIL image
+     if isinstance(first_frame, str):
+         transparent_background = Image.open(first_frame).convert('RGBA')
+     else:
+         transparent_background = first_frame.convert('RGBA')
+     w, h = transparent_background.size
+
+     # Define the circle radius as a fraction of the smaller dimension
+     fraction = 0.02  # Adjust this value as needed
+     radius = int(fraction * min(w, h))
+
+     # Create a transparent layer to draw on
+     transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
+
+     # Green dots mark include points, red dots mark exclude points
+     for index, track in enumerate(tracking_points):
+         if trackings_input_label[index] == 1:
+             cv2.circle(transparent_layer, tuple(track), radius, (0, 255, 0, 255), -1)
+         else:
+             cv2.circle(transparent_layer, tuple(track), radius, (255, 0, 0, 255), -1)
+
+     # Convert the transparent layer back to an image and composite it
+     transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
+     selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+     return tracking_points, trackings_input_label, selected_point_map
+
+ def show_mask(mask, ax, random_color=False, borders=True):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         color = np.array([30/255, 144/255, 255/255, 0.6])
+     h, w = mask.shape[-2:]
+     mask = mask.astype(np.uint8)
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     if borders:
+         contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+         # Try to smooth contours
+         contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
+         mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
+     ax.imshow(mask_image)
+
+ def show_points(coords, labels, ax, marker_size=375):
+     pos_points = coords[labels == 1]
+     neg_points = coords[labels == 0]
+     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+ def show_box(box, ax):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
+
+ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
+     combined_images = []  # List to store filenames of images with masks overlaid
+     mask_images = []      # List to store filenames of separate mask images
+     for i, (mask, score) in enumerate(zip(masks, scores)):
+         # ---- Original image with mask overlaid ----
+         plt.figure(figsize=(10, 10))
+         plt.imshow(image)
+         show_mask(mask, plt.gca(), borders=borders)  # Draw the mask with borders
+         if box_coords is not None:
+             show_box(box_coords, plt.gca())
+         if len(scores) > 1:
+             plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
+         plt.axis('off')
+         # Save the figure as a JPG file
+         combined_filename = f"combined_image_{i+1}.jpg"
+         plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
+         combined_images.append(combined_filename)
+         plt.close()  # Close the figure to free up memory
+
+         # ---- Separate mask image (white mask on black background) ----
+         mask_image = np.zeros_like(image, dtype=np.uint8)
+         # The mask is a binary array where the masked area is 1, else 0;
+         # convert it to white in every channel of mask_image
+         mask_layer = (mask > 0).astype(np.uint8) * 255
+         for c in range(3):  # Assuming RGB, repeat the mask for all channels
+             mask_image[:, :, c] = mask_layer
+         # Save the mask image
+         mask_filename = f"mask_image_{i+1}.png"
+         Image.fromarray(mask_image).save(mask_filename)
+         mask_images.append(mask_filename)
+     return combined_images, mask_images
+
+ @spaces.GPU()
+ def sam_process(original_image, points, labels):
+     print(f"Points: {points}")
+     print(f"Labels: {labels}")
+     if not points or not labels:
+         print("No points or labels provided, returning None")
+         return None
+
+     # ImagePrompter values arrive as a dict with an "image" key, so
+     # accept either that dict or a plain image
+     if isinstance(original_image, dict):
+         original_image = original_image.get("image")
+
+     # Convert image to numpy array for SAM2 processing
+     image = np.array(original_image)
+     predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
+     predictor.set_image(image)
+
+     input_point = np.array(points)
+     input_label = np.array(labels)
+
+     try:
+         masks, scores, _ = predictor.predict(input_point, input_label, multimask_output=False)
+     except Exception as e:
+         print(f"Error during prediction: {e}")
+         return None
+
+     # Keep the highest-scoring mask first
+     sorted_indices = np.argsort(scores)[::-1]
+     masks = masks[sorted_indices]
+
+     # masks is a numpy array, so test it explicitly; a bare `if masks`
+     # raises "truth value is ambiguous" for multi-element arrays
+     if masks is not None and len(masks) > 0:
+         mask = masks[0] * 255
+         mask_image = Image.fromarray(mask.astype(np.uint8))
+         return mask_image
+     else:
+         print("No masks generated, returning None")
+         return None
+
+ def create_sam2_tab():
+     first_frame = gr.State()  # Tracks the original image
+     tracking_points = gr.State([])
+     trackings_input_label = gr.State([])
+
+     with gr.Column():
+         gr.Markdown("# SAM2 Image Predictor")
+
+         input_image = ImagePrompter(show_label=False)
+         points_map = gr.Image(label="Points Map", type="pil", interactive=True)
+         # image_input = gr.Image(type="pil", visible=False)  # Original image
+
+         with gr.Row():
+             point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
+             clear_button = gr.Button("Clear Points")
+
+         submit_button = gr.Button("Submit")
+         output_image = gr.Image(label="Segmented Output")
+
+     # Event handlers
+     # Mirror the uploaded image into the prompter and reset the point state
+     points_map.upload(
+         lambda img: (img, img, [], []),
+         inputs=points_map,
+         outputs=[input_image, first_frame, tracking_points, trackings_input_label]
+     )
+
+     clear_button.click(
+         lambda img: ([], [], img),
+         inputs=first_frame,
+         outputs=[tracking_points, trackings_input_label, points_map]
+     )
+
+     points_map.select(
+         get_point,
+         inputs=[point_type, tracking_points, trackings_input_label, first_frame],
+         outputs=[tracking_points, trackings_input_label, points_map]
+     )
+
+     submit_button.click(
+         sam_process,
+         inputs=[input_image, tracking_points, trackings_input_label],
+         outputs=output_image
+     )
+
+     return input_image, points_map, output_image
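
For context, a minimal sketch of how the new factory is meant to be mounted, mirroring the app.py hunk above; the surrounding gr.Blocks/gr.Tabs scaffolding shown here is illustrative, not part of this commit:

import gradio as gr
from sam2_mask import create_sam2_tab

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("SAM2 Masking"):
            # create_sam2_tab() builds its own column of components, wires
            # the upload/select/click handlers, and returns the three
            # components a caller is most likely to reference
            input_image, points_map, output_mask = create_sam2_tab()

demo.launch()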