LPX55 committed on
Commit
f7c1ee7
·
1 Parent(s): f131e81

attempt was made

Files changed (1):
sam2_mask.py (+56, -102)
sam2_mask.py CHANGED
@@ -13,32 +13,20 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
 def preprocess_image(image):
     return image, gr.State([]), gr.State([]), image
 
-def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
-    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
-    tracking_points.value.append(evt.index)
-    print(f"TRACKING POINT: {tracking_points.value}")
-    if point_type == "include":
-        trackings_input_label.value.append(1)
-    elif point_type == "exclude":
-        trackings_input_label.value.append(0)
-    print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
-    # Open the image and get its dimensions
-    transparent_background = Image.open(first_frame_path).convert('RGBA')
-    w, h = transparent_background.size
-    # Define the circle radius as a fraction of the smaller dimension
-    fraction = 0.02  # You can adjust this value as needed
-    radius = int(fraction * min(w, h))
-    # Create a transparent layer to draw on
-    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
-    for index, track in enumerate(tracking_points.value):
-        if trackings_input_label.value[index] == 1:
-            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
-        else:
-            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
-    # Convert the transparent layer back to an image
-    transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
-    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
-    return tracking_points, trackings_input_label, selected_point_map
+def get_point(point_type, tracking_points, trackings_input_label, original_image, evt: gr.SelectData):
+    x, y = evt.index
+    tracking_points.append((x, y))
+    trackings_input_label.append(1 if point_type == "include" else 0)
+
+    # Redraw all points on a copy of the original image
+    w, h = original_image.size
+    radius = int(min(w, h) * 0.02)
+    img = original_image.convert("RGBA")
+    draw = ImageDraw.Draw(img)
+    for i, (cx, cy) in enumerate(tracking_points):
+        color = (0, 255, 0, 255) if trackings_input_label[i] == 1 else (255, 0, 0, 255)
+        draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=color)
+    return tracking_points, trackings_input_label, img
 
 # use bfloat16 for the entire notebook
 torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
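
Note that the reworked get_point draws with PIL's ImageDraw rather than cv2 and numpy, so ImageDraw has to be importable in sam2_mask.py; the hunk starts below the import block, so that change isn't visible here. The handler also no longer needs Gradio to run: only evt.index is read, so a stub event object is enough for a quick check. A minimal sketch, with get_point defined as above and a hypothetical fake_event standing in for gr.SelectData:

from types import SimpleNamespace
from PIL import Image, ImageDraw  # ImageDraw must be in scope for get_point

# Stand-in for gr.SelectData; get_point only reads evt.index.
fake_event = SimpleNamespace(index=(120, 80))

base = Image.new("RGB", (640, 480), "white")
points, labels, preview = get_point("include", [], [], base, fake_event)
print(points, labels, preview.size)  # [(120, 80)] [1] (640, 480)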
@@ -108,96 +96,62 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
     return combined_images, mask_images
 
 @spaces.GPU()
-def sam_process(input_image, tracking_points, trackings_input_label):
-    image = Image.open(input_image)
-    image = np.array(image.convert("RGB"))
-    # if checkpoint == "tiny":
-    #     sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
-    #     model_cfg = "sam2_hiera_t.yaml"
-    # elif checkpoint == "small":
-    #     sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
-    #     model_cfg = "sam2_hiera_s.yaml"
-    # elif checkpoint == "base-plus":
-    #     sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
-    #     model_cfg = "sam2_hiera_b+.yaml"
-    # elif checkpoint == "large":
-    #     sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
-    #     model_cfg = "sam2_hiera_l.yaml"
+def sam_process(original_image, points, labels):
+    # Convert the PIL image to a numpy array for SAM2 processing
+    image = np.array(original_image)
     predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
-    # print(predictor)
-    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
-        predictor.set_image(image)
-        input_point = np.array(tracking_points.value)
-        input_label = np.array(trackings_input_label.value)
-        print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)
-        masks, scores, logits = predictor.predict(
-            point_coords=input_point,
-            point_labels=input_label,
-            multimask_output=False,
-        )
-        sorted_ind = np.argsort(scores)[::-1]
-        masks = masks[sorted_ind]
-        scores = scores[sorted_ind]
-        logits = logits[sorted_ind]
-        print(masks.shape)
-        results, mask_results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
-        print(results)
-        return results[0], mask_results[0]
+    predictor.set_image(image)
+    input_point = np.array(points)
+    input_label = np.array(labels)
+    masks, scores, _ = predictor.predict(input_point, input_label, multimask_output=False)
+    sorted_indices = np.argsort(scores)[::-1]
+    masks = masks[sorted_indices]
+
+    # Generate the mask image
+    mask = masks[0] * 255
+    mask_image = Image.fromarray(mask.astype(np.uint8))
+    return mask_image
 # sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
 # predictor = SAM2ImagePredictor(sam2_model)
 
 
-
 def create_sam2_tab():
-    first_frame_path = gr.State()
+    first_frame = gr.State()  # Tracks the original image
     tracking_points = gr.State([])
     trackings_input_label = gr.State([])
+
     with gr.Column():
         gr.Markdown("# SAM2 Image Predictor")
-        gr.Markdown("This is a simple demo for image segmentation with SAM2.")
-        gr.Markdown("""Instructions:
-        1. Upload your image
-        2. With 'include' point type selected, Click on the object to mask
-        3. Switch to 'exclude' point type if you want to specify an area to avoid
-        4. Submit !
-        """)
+        gr.Markdown("1. Upload your image\n2. Click points to mask\n3. Submit")
+        points_map = gr.Image(label="Points Map", type="pil", interactive=True)
+        input_image = gr.Image(type="pil", visible=False)  # Original image
+
         with gr.Row():
-            with gr.Column():
-                input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
-                points_map = gr.Image(
-                    label="points map",
-                    type="filepath",
-                    interactive=True
-                )
-                with gr.Row():
-                    point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
-                    clear_points_btn = gr.Button("Clear Points")
-                # checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="tiny")
-                submit_btn = gr.Button("Submit")
-            with gr.Column():
-                output_result = gr.Image()
-                output_result_mask = gr.Image()
-        clear_points_btn.click(
-            fn=preprocess_image,
-            inputs=input_image,
-            outputs=[first_frame_path, tracking_points, trackings_input_label, points_map],
-            queue=False
-        )
+            point_type = gr.Radio(["include", "exclude"], value="include", label="Point Type")
+            clear_button = gr.Button("Clear Points")
+            submit_button = gr.Button("Submit")
+        output_image = gr.Image(label="Segmented Output")
+
+        # Event handlers
         points_map.upload(
-            fn=preprocess_image,
-            inputs=[points_map],
-            outputs=[first_frame_path, tracking_points, trackings_input_label, input_image],
-            queue=False
+            lambda img: (img, img, [], []),
+            inputs=points_map,
+            outputs=[input_image, first_frame, tracking_points, trackings_input_label]
+        )
+        clear_button.click(
+            lambda img: ([], [], img),
+            inputs=first_frame,
+            outputs=[tracking_points, trackings_input_label, points_map]
         )
         points_map.select(
-            fn=get_point,
-            inputs=[point_type, tracking_points, trackings_input_label, first_frame_path],
-            outputs=[tracking_points, trackings_input_label, points_map],
-            queue=False
+            get_point,
+            inputs=[point_type, tracking_points, trackings_input_label, first_frame],
+            outputs=[tracking_points, trackings_input_label, points_map]
        )
-        submit_btn.click(
-            fn=sam_process,
+        submit_button.click(
+            sam_process,
             inputs=[input_image, tracking_points, trackings_input_label],
-            outputs=[output_result, output_result_mask]
         )
-        return input_image, points_map, output_result_mask
+            outputs=output_image
+        )
+
+    return input_image, points_map, output_image
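Stripped of Gradio state objects, sam_process is now a plain image-plus-points function, so it can be smoke-tested from a script. A minimal sketch, assuming a CUDA machine (the module enters a global bfloat16 autocast and wraps the function in spaces.GPU()), Hub access for the facebook/sam2.1-hiera-large weights, and a hypothetical example.jpg, with sam_process defined as in the hunk above:

from PIL import Image

# Hypothetical input; any RGB image works.
img = Image.open("example.jpg").convert("RGB")
points = [(250, 300)]  # one "include" click, in (x, y) pixel coordinates
labels = [1]           # 1 = include, 0 = exclude

mask = sam_process(img, points, labels)  # single-channel PIL image (0 or 255)
mask.save("mask.png")

Because the rewrite returns only the raw mask (the old version also returned the show_masks overlay), any compositing is now the caller's job. For example, a transparent cutout:

# Use the mask as an alpha channel to cut the object out of the original.
cutout = img.convert("RGBA")
cutout.putalpha(mask)  # SAM2 masks come back at the input resolution
cutout.save("cutout.png")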
 
 
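Finally, create_sam2_tab still returns (input_image, points_map, output_image), so a parent Blocks app can embed the tab and wire those components to other parts of the UI. A sketch of a hypothetical host app:

import gradio as gr

# Hypothetical host layout; only gr.Blocks / gr.Tab are assumed.
with gr.Blocks() as demo:
    with gr.Tab("SAM2 Mask"):
        input_image, points_map, output_image = create_sam2_tab()

demo.launch()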