LPX55 commited on
Commit
72284aa
·
verified ·
1 Parent(s): 054a32e

Update sam2_mask.py

Browse files
Files changed (1) hide show
  1. sam2_mask.py +139 -153
sam2_mask.py CHANGED
@@ -1,59 +1,42 @@
1
- import spaces
2
  import gradio as gr
3
  import os
4
- os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
5
  import torch
6
  import numpy as np
7
  import cv2
 
8
  import matplotlib.pyplot as plt
9
- from PIL import Image, ImageFilter
10
  from sam2.build_sam import build_sam2
11
  from sam2.sam2_image_predictor import SAM2ImagePredictor
12
- from gradio_image_prompter import ImagePrompter
13
 
14
- # def preprocess_image(image):
15
- # return image, gr.State([]), gr.State([]), image
 
16
 
17
  def preprocess_image(image):
18
- return image, [], [], image
19
 
20
  def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
21
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
22
-
23
- # Extract x, y coordinates from evt.index
24
- x, y = evt.index
25
-
26
- # Add the point as [x, y]
27
- tracking_points.append([x, y])
28
- print(f"TRACKING POINTS: {tracking_points}")
29
-
30
  if point_type == "include":
31
- trackings_input_label.append(1)
32
  elif point_type == "exclude":
33
- trackings_input_label.append(0)
34
- print(f"TRACKING INPUT LABELS: {trackings_input_label}")
35
-
36
- # Open the image and get its dimensions
37
  transparent_background = Image.open(first_frame_path).convert('RGBA')
38
  w, h = transparent_background.size
39
-
40
- # Define the circle radius as a fraction of the smaller dimension
41
- fraction = 0.02 # You can adjust this value as needed
42
  radius = int(fraction * min(w, h))
43
-
44
- # Create a transparent layer to draw on
45
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
46
-
47
- for index, track in enumerate(tracking_points):
48
- if trackings_input_label[index] == 1:
49
- cv2.circle(transparent_layer, tuple(track), radius, (0, 255, 0, 255), -1)
50
  else:
51
- cv2.circle(transparent_layer, tuple(track), radius, (255, 0, 0, 255), -1)
52
-
53
- # Convert the transparent layer back to an image
54
  transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
55
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
56
-
57
  return tracking_points, trackings_input_label, selected_point_map
58
 
59
  def show_mask(mask, ax, random_color=False, borders=True):
@@ -63,156 +46,159 @@ def show_mask(mask, ax, random_color=False, borders=True):
63
  color = np.array([30/255, 144/255, 255/255, 0.6])
64
  h, w = mask.shape[-2:]
65
  mask = mask.astype(np.uint8)
66
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
67
  if borders:
68
- import cv2
69
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
70
- # Try to smooth contours
71
  contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
72
  mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
73
  ax.imshow(mask_image)
74
 
75
- def show_points(coords, labels, ax, marker_size=375):
76
- pos_points = coords[labels == 1]
77
- neg_points = coords[labels == 0]
78
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
79
- ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
80
-
81
-
82
- def show_mask(mask, ax, random_color=False, borders=True):
83
- if random_color:
84
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
85
- else:
86
- color = np.array([30/255, 144/255, 255/255, 0.6])
87
- h, w = mask.shape[-2:]
88
- mask = mask.astype(np.uint8)
89
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
90
- if borders:
91
- import cv2
92
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
93
- # Try to smooth contours
94
- contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
95
- mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
96
- ax.imshow(mask_image)
97
- def show_points(coords, labels, ax, marker_size=375):
98
- pos_points = coords[labels == 1]
99
- neg_points = coords[labels == 0]
100
- ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
101
- ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
102
 
103
  def show_box(box, ax):
104
  x0, y0 = box[0], box[1]
105
  w, h = box[2] - box[0], box[3] - box[1]
106
- ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
107
 
108
  def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
109
- combined_images = [] # List to store filenames of images with masks overlaid
110
- mask_images = [] # List to store filenames of separate mask images
111
  for i, (mask, score) in enumerate(zip(masks, scores)):
112
- # ---- Original Image with Mask Overlaid ----
113
  plt.figure(figsize=(10, 10))
114
  plt.imshow(image)
115
- show_mask(mask, plt.gca(), borders=borders) # Draw the mask with borders
116
- if box_coords is not None:
117
- show_box(box_coords, plt.gca())
118
- if len(scores) > 1:
119
- plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
120
  plt.axis('off')
121
- # Save the figure as a JPG file
122
  combined_filename = f"combined_image_{i+1}.jpg"
123
  plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
124
  combined_images.append(combined_filename)
125
- plt.close() # Close the figure to free up memory
126
- # ---- Separate Mask Image (White Mask on Black Background) ----
127
- # Create a black image
128
  mask_image = np.zeros_like(image, dtype=np.uint8)
129
- # The mask is a binary array where the masked area is 1, else 0.
130
- # Convert the mask to a white color in the mask_image
131
  mask_layer = (mask > 0).astype(np.uint8) * 255
132
- for c in range(3): # Assuming RGB, repeat mask for all channels
133
  mask_image[:, :, c] = mask_layer
134
- # Save the mask image
135
  mask_filename = f"mask_image_{i+1}.png"
136
  Image.fromarray(mask_image).save(mask_filename)
137
  mask_images.append(mask_filename)
138
- plt.close() # Close the figure to free up memory
139
  return combined_images, mask_images
140
 
141
- @spaces.GPU()
142
- def sam_process(original_image, points, labels):
143
- print(f"Points: {points}")
144
- print(f"Labels: {labels}")
145
- if not points or not labels:
146
- print("No points or labels provided, returning None")
147
- return None
148
-
149
- # Convert image to numpy array for SAM2 processing
150
- image = np.array(original_image)
151
- predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
152
- predictor.set_image(image)
153
-
154
- input_point = np.array(points)
155
- input_label = np.array(labels)
156
-
157
- try:
158
- masks, scores, _ = predictor.predict(input_point, input_label, multimask_output=False)
159
- except Exception as e:
160
- print(f"Error during prediction: {e}")
161
- return None
162
-
163
- sorted_indices = np.argsort(scores)[::-1]
164
- masks = masks[sorted_indices]
165
-
166
- if masks and len(masks) > 0:
167
- mask = masks[0] * 255
168
- mask_image = Image.fromarray(mask.astype(np.uint8))
169
- return mask_image
170
  else:
171
- print("No masks generated, returning None")
172
- return None
173
 
174
- def create_sam2_tab():
175
- first_frame = gr.State() # Tracks original image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  tracking_points = gr.State([])
177
  trackings_input_label = gr.State([])
178
-
179
  with gr.Column():
180
- gr.Markdown("# SAM2 Image Predictor")
181
- gr.Markdown("1. Upload your image\n2. Click points to mask\n3. Submit")
182
-
183
- points_map = gr.Image(label="Points Map", type="pil", interactive=True)
184
- input_image = gr.Image(type="pil", visible=False) # Original image
185
-
186
  with gr.Row():
187
- point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
188
- clear_button = gr.Button("Clear Points")
189
-
190
- submit_button = gr.Button("Submit")
191
- output_image = gr.Image("Segmented Output")
192
-
193
- # Event handlers
194
- points_map.upload(
195
- lambda img: (img, img, [], []),
196
- inputs=points_map,
197
- outputs=[input_image, first_frame, tracking_points, trackings_input_label]
198
- )
199
-
200
- clear_button.click(
201
- lambda img: ([], [], img),
202
- inputs=first_frame,
203
- outputs=[tracking_points, trackings_input_label, points_map]
204
- )
205
-
206
- points_map.select(
207
- get_point,
208
- inputs=[point_type, tracking_points, trackings_input_label, first_frame],
209
- outputs=[tracking_points, trackings_input_label, points_map]
210
- )
211
-
212
- submit_button.click(
213
- sam_process,
214
- inputs=[input_image, tracking_points, trackings_input_label],
215
- outputs=output_image
216
- )
217
-
218
- return input_image, points_map, output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import os
 
3
  import torch
4
  import numpy as np
5
  import cv2
6
+ import huggingface_hub
7
  import matplotlib.pyplot as plt
8
+ from PIL import Image
9
  from sam2.build_sam import build_sam2
10
  from sam2.sam2_image_predictor import SAM2ImagePredictor
 
11
 
12
+
13
+ # Remove all CUDA-specific configurations
14
+ torch.autocast(device_type="cpu", dtype=torch.float32).__enter__()
15
 
16
  def preprocess_image(image):
17
+ return image, gr.State([]), gr.State([]), image
18
 
19
  def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
20
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
21
+ tracking_points.value.append(evt.index)
22
+ print(f"TRACKING POINT: {tracking_points.value}")
 
 
 
 
 
 
23
  if point_type == "include":
24
+ trackings_input_label.value.append(1)
25
  elif point_type == "exclude":
26
+ trackings_input_label.value.append(0)
27
+ print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
 
 
28
  transparent_background = Image.open(first_frame_path).convert('RGBA')
29
  w, h = transparent_background.size
30
+ fraction = 0.02
 
 
31
  radius = int(fraction * min(w, h))
 
 
32
  transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
33
+ for index, track in enumerate(tracking_points.value):
34
+ if trackings_input_label.value[index] == 1:
35
+ cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
 
36
  else:
37
+ cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
 
 
38
  transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
39
  selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
 
40
  return tracking_points, trackings_input_label, selected_point_map
41
 
42
  def show_mask(mask, ax, random_color=False, borders=True):
 
46
  color = np.array([30/255, 144/255, 255/255, 0.6])
47
  h, w = mask.shape[-2:]
48
  mask = mask.astype(np.uint8)
49
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
50
  if borders:
51
+ contours, _= cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
 
 
52
  contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
53
  mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
54
  ax.imshow(mask_image)
55
 
56
+ def show_points(coords, labels, ax, marker_size=200):
57
+ pos_points = coords[labels==1]
58
+ neg_points = coords[labels==0]
59
  ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
60
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  def show_box(box, ax):
63
  x0, y0 = box[0], box[1]
64
  w, h = box[2] - box[0], box[3] - box[1]
65
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
66
 
67
  def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
68
+ combined_images = []
69
+ mask_images = []
70
  for i, (mask, score) in enumerate(zip(masks, scores)):
 
71
  plt.figure(figsize=(10, 10))
72
  plt.imshow(image)
73
+ show_mask(mask, plt.gca(), borders=borders)
 
 
 
 
74
  plt.axis('off')
 
75
  combined_filename = f"combined_image_{i+1}.jpg"
76
  plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
77
  combined_images.append(combined_filename)
78
+ plt.close()
 
 
79
  mask_image = np.zeros_like(image, dtype=np.uint8)
 
 
80
  mask_layer = (mask > 0).astype(np.uint8) * 255
81
+ for c in range(3):
82
  mask_image[:, :, c] = mask_layer
 
83
  mask_filename = f"mask_image_{i+1}.png"
84
  Image.fromarray(mask_image).save(mask_filename)
85
  mask_images.append(mask_filename)
 
86
  return combined_images, mask_images
87
 
88
+ def expand_contract_mask(mask, px, expand=True):
89
+ kernel = np.ones((px, px), np.uint8)
90
+ if expand:
91
+ return cv2.dilate(mask, kernel, iterations=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  else:
93
+ return cv2.erode(mask, kernel, iterations=1)
 
94
 
95
+ def feather_mask(mask, feather_size=10):
96
+ feathered_mask = mask.copy()
97
+ Feathered_region = mask > 0
98
+ Feathered_region = cv2.dilate(Feathered_region.astype(np.uint8), np.ones((feather_size, feather_size), np.uint8), iterations=1)
99
+ Feathered_region = Feathered_region & (~mask.astype(bool))
100
+
101
+ for i in range(1, feather_size + 1):
102
+ weight = i / (feather_size + 1)
103
+ feathered_mask[Feathered_region] = feathered_mask[Feathered_region] * (1 - weight) + weight
104
+
105
+ return feathered_mask
106
+
107
+ def process_mask(mask, expand_contract_px, expand, feathering_enabled, feather_size):
108
+ if expand_contract_px > 0:
109
+ mask = expand_contract_mask(mask, expand_contract_px, expand)
110
+ if feathering_enabled:
111
+ mask = feather_mask(mask, feather_size)
112
+ return mask
113
+
114
+ def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
115
+ image = Image.open(input_image)
116
+ image = np.array(image.convert("RGB"))
117
+ sam21_hfmap = {
118
+ "tiny": "facebook/sam2.1-hiera-tiny",
119
+ "small": "facebook/sam2.1-hiera-small",
120
+ "base-plus": "facebook/sam2.1-hiera-base-plus",
121
+ "large": "facebook/sam2.1-hiera-large",
122
+ }
123
+ # sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
124
+ # Use CPU for both model and computations
125
+ # sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
126
+ predictor = SAM2ImagePredictor.from_pretrained(sam21_hfmap[checkpoint], device="cpu")
127
+
128
+ # predictor = SAM2ImagePredictor(sam2_model)
129
+ predictor.set_image(image)
130
+ input_point = np.array(tracking_points.value)
131
+ input_label = np.array(trackings_input_label.value)
132
+ masks, scores, logits = predictor.predict(
133
+ point_coords=input_point,
134
+ point_labels=input_label,
135
+ multimask_output=False,
136
+ )
137
+ sorted_ind = np.argsort(scores)[::-1]
138
+ masks = masks[sorted_ind]
139
+ scores = scores[sorted_ind]
140
+ processed_masks = []
141
+ for mask in masks:
142
+ processed_mask = process_mask(mask, expand_contract_px, expand, feathering_enabled, feather_size)
143
+ processed_masks.append(processed_mask)
144
+ results, mask_results = show_masks(image, processed_masks, scores,
145
+ point_coords=input_point,
146
+ input_labels=input_label,
147
+ borders=True)
148
+ return results[0], mask_results[0]
149
+
150
+ with gr.Blocks() as demo:
151
+ first_frame_path = gr.State()
152
  tracking_points = gr.State([])
153
  trackings_input_label = gr.State([])
 
154
  with gr.Column():
155
+ gr.Markdown("# SAM2 Image Predictor (CPU Version)")
156
+ gr.Markdown("This version runs entirely on CPU")
 
 
 
 
157
  with gr.Row():
158
+ with gr.Column():
159
+ input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
160
+ points_map = gr.Image(label="points map", type="filepath", interactive=True)
161
+ with gr.Row():
162
+ point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
163
+ clear_points_btn = gr.Button("Clear Points")
164
+ checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="base-plus")
165
+ with gr.Row():
166
+ expand_contract_px = gr.Slider(minimum=0, maximum=50, value=0, label="Expand/Contract (pixels)")
167
+ expand = gr.Radio(["Expand", "Contract"], value="Expand", label="Action")
168
+ with gr.Row():
169
+ feathering_enabled = gr.Checkbox(value=False, label="Enable Feathering")
170
+ feather_size = gr.Slider(minimum=1, maximum=50, value=10, label="Feathering Size", visible=False)
171
+ submit_btn = gr.Button("Submit")
172
+ with gr.Column():
173
+ output_result = gr.Image()
174
+ output_result_mask = gr.Image()
175
+ clear_points_btn.click(
176
+ fn=preprocess_image,
177
+ inputs=input_image,
178
+ outputs=[first_frame_path, tracking_points, trackings_input_label, points_map],
179
+ queue=False
180
+ )
181
+ points_map.upload(
182
+ fn=preprocess_image,
183
+ inputs=[points_map],
184
+ outputs=[first_frame_path, tracking_points, trackings_input_label, input_image],
185
+ queue=False
186
+ )
187
+ points_map.select(
188
+ fn=get_point,
189
+ inputs=[point_type, tracking_points, trackings_input_label, first_frame_path],
190
+ outputs=[tracking_points, trackings_input_label, points_map],
191
+ queue=False
192
+ )
193
+ submit_btn.click(
194
+ fn=sam_process,
195
+ inputs=[input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size],
196
+ outputs=[output_result, output_result_mask]
197
+ )
198
+ feathering_enabled.change(
199
+ fn=lambda enabled: gr.update(visible=enabled),
200
+ inputs=[feathering_enabled],
201
+ outputs=[feather_size]
202
+ )
203
+
204
+ demo.launch(show_error=True)