NingKanae An-619 committed
Commit 7273732 · 0 Parent(s)

Duplicate from An-619/FastSAM


Co-authored-by: Yongqi An <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,47 @@
+ ---
+ title: FastSAM
+ emoji: 🐠
+ colorFrom: pink
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.35.2
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: An-619/FastSAM
+ ---
+
+ # Fast Segment Anything
+
+ Official PyTorch implementation of [Fast Segment Anything](https://github.com/CASIA-IVA-Lab/FastSAM).
+
+ The **Fast Segment Anything Model (FastSAM)** is a CNN-based Segment Anything Model trained on only 2% of the SA-1B dataset published by the SAM authors. FastSAM achieves performance comparable to
+ the SAM method at **50× higher run-time speed**.
+
+
+ ## License
+
+ The model is licensed under the [Apache 2.0 license](LICENSE).
+
+
+ ## Acknowledgement
+
+ - [Segment Anything](https://segment-anything.com/) provides the SA-1B dataset and the base codes.
+ - [YOLOv8](https://github.com/ultralytics/ultralytics) provides codes and pre-trained models.
+ - [YOLACT](https://arxiv.org/abs/2112.10003) provides a powerful instance segmentation method.
+ - [Grounded-Segment-Anything](https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything) provides a useful web demo template.
+
+ ## Citing FastSAM
+
+ If you find this project useful for your research, please consider citing the following BibTeX entry.
+
+ ```
+ @misc{zhao2023fast,
+   title={Fast Segment Anything},
+   author={Xu Zhao and Wenchao Ding and Yongqi An and Yinglong Du and Tao Yu and Min Li and Ming Tang and Jinqiao Wang},
+   year={2023},
+   eprint={2306.12156},
+   archivePrefix={arXiv},
+   primaryClass={cs.CV}
+ }
+ ```
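For context when reading the diffs below, here is a minimal sketch of how this Space runs FastSAM in Everything mode. It mirrors `app.py` from this commit (which additionally resizes the input so its longer side matches `imgsz`); `assets/sa_192.jpg` is simply one of the example images bundled in this repository.

```python
# Minimal Everything-mode inference, following app.py from this commit.
import torch
from PIL import Image
from ultralytics import YOLO
from tools import fast_process

model = YOLO('checkpoints/FastSAM.pt')            # LFS-tracked weights added in this commit
device = 'cuda' if torch.cuda.is_available() else 'cpu'

image = Image.open('assets/sa_192.jpg')           # one of the bundled example images
results = model(image, device=device, retina_masks=True,
                iou=0.7, conf=0.25, imgsz=1024)

# Overlay all predicted masks on the input; fast_process returns a PIL image.
overlay = fast_process(annotations=results[0].masks.data, image=image,
                       device=device, scale=1, better_quality=False,
                       mask_random_color=True, bbox=None,
                       use_retina=True, withContours=True)
overlay.save('everything_mode.png')
```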
__pycache__/tools.cpython-39.pyc ADDED
Binary file (8.4 kB).
 
app.py ADDED
@@ -0,0 +1,289 @@
1
+ from ultralytics import YOLO
2
+ import gradio as gr
3
+ import torch
4
+ from tools import fast_process, format_results, box_prompt, point_prompt
5
+ from PIL import ImageDraw
6
+ import numpy as np
7
+
8
+ # Load the pre-trained model
9
+ model = YOLO('checkpoints/FastSAM.pt')
10
+
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+
13
+ # Description
14
+ title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
15
+
16
+ news = """ # 📖 News
17
+
18
+ 🔥 2023/06/24: Add the 'Advanced options' in Everything mode to get a more detailed adjustment.
19
+
20
+ 🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
21
+
22
+ """
23
+
24
+ description_e = """This is a demo of the GitHub project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
25
+
26
+ 🎯 Upload an image and segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
27
+
28
+ ⌛️ It takes about 6 seconds to generate segmentation results. The queue's concurrency_count is 1, so please wait a moment when the demo is busy.
29
+
30
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
31
+
32
+ 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
33
+
34
+ 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with a GPU grant.
35
+
36
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
37
+
38
+ """
39
+
40
+ description_p = """This is a demo of the GitHub project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
41
+
42
+ 🎯 Upload an image, add points, and segment it with Fast Segment Anything (Points mode).
43
+
44
+ ⌛️ It takes about 6 seconds to generate segmentation results. The queue's concurrency_count is 1, so please wait a moment when the demo is busy.
45
+
46
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
47
+
48
+ 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
49
+
50
+ 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with a GPU grant.
51
+
52
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
53
+
54
+ """
55
+
56
+ examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"], ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
57
+ ["assets/sa_561.jpg"], ["assets/sa_192.jpg"], ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
58
+
59
+ default_example = examples[0]
60
+
61
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
62
+
63
+
64
+ def segment_everything(
65
+ input,
66
+ input_size=1024,
67
+ iou_threshold=0.7,
68
+ conf_threshold=0.25,
69
+ better_quality=False,
70
+ withContours=True,
71
+ use_retina=True,
72
+ mask_random_color=True,
73
+ ):
74
+ input_size = int(input_size)  # make sure imgsz is an integer
75
+
76
+ # Thanks for the suggestion by hysts in HuggingFace.
77
+ w, h = input.size
78
+ scale = input_size / max(w, h)
79
+ new_w = int(w * scale)
80
+ new_h = int(h * scale)
81
+ input = input.resize((new_w, new_h))
82
+
83
+ results = model(input,
84
+ device=device,
85
+ retina_masks=True,
86
+ iou=iou_threshold,
87
+ conf=conf_threshold,
88
+ imgsz=input_size,)
89
+
90
+ fig = fast_process(annotations=results[0].masks.data,
91
+ image=input,
92
+ device=device,
93
+ scale=(1024 // input_size),
94
+ better_quality=better_quality,
95
+ mask_random_color=mask_random_color,
96
+ bbox=None,
97
+ use_retina=use_retina,
98
+ withContours=withContours,)
99
+ return fig
100
+
101
+ def segment_with_points(
102
+ input,
103
+ input_size=1024,
104
+ iou_threshold=0.7,
105
+ conf_threshold=0.25,
106
+ better_quality=False,
107
+ withContours=True,
108
+ mask_random_color=True,
109
+ use_retina=True,
110
+ ):
111
+ global global_points
112
+ global global_point_label
113
+
114
+ input_size = int(input_size)  # make sure imgsz is an integer
115
+ # Thanks for the suggestion by hysts in HuggingFace.
116
+ w, h = input.size
117
+ scale = input_size / max(w, h)
118
+ new_w = int(w * scale)
119
+ new_h = int(h * scale)
120
+ input = input.resize((new_w, new_h))
121
+
122
+ scaled_points = [[int(x * scale) for x in point] for point in global_points]
123
+
124
+ results = model(input,
125
+ device=device,
126
+ retina_masks=True,
127
+ iou=iou_threshold,
128
+ conf=conf_threshold,
129
+ imgsz=input_size,)
130
+
131
+ results = format_results(results[0], 0)
132
+
133
+ annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
134
+ annotations = np.array([annotations])
135
+
136
+ fig = fast_process(annotations=annotations,
137
+ image=input,
138
+ device=device,
139
+ scale=(1024 // input_size),
140
+ better_quality=better_quality,
141
+ mask_random_color=mask_random_color,
142
+ bbox=None,
143
+ use_retina=use_retina,
144
+ withContours=withContours,)
145
+ global_points = []
146
+ global_point_label = []
147
+ return fig, None
148
+
149
+ def get_points_with_draw(image, label, evt: gr.SelectData):
150
+ x, y = evt.index[0], evt.index[1]
151
+ point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
152
+ global global_points
153
+ global global_point_label
154
+ print((x, y))
155
+ global_points.append([x, y])
156
+ global_point_label.append(1 if label == 'Add Mask' else 0)
157
+
158
+ # create a drawing context for the image
159
+ draw = ImageDraw.Draw(image)
160
+ draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
161
+ return image
162
+
163
+
164
+ # input_size=1024
165
+ # high_quality_visual=True
166
+ # inp = 'assets/sa_192.jpg'
167
+ # input = Image.open(inp)
168
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
169
+ # input_size = int(input_size) # make sure imgsz is an integer
170
+ # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
171
+ # pil_image = fast_process(annotations=results[0].masks.data,
172
+ # image=input, high_quality=high_quality_visual, device=device)
173
+
174
+ cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
175
+ cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
176
+
177
+ segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
178
+ segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
179
+
180
+ global_points = []
181
+ global_point_label = [] # TODO:Clear points each image
182
+
183
+ input_size_slider = gr.components.Slider(minimum=512,
184
+ maximum=1024,
185
+ value=1024,
186
+ step=64,
187
+ label='Input_size',
188
+ info='Our model was trained on a size of 1024')
189
+
190
+ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
191
+ with gr.Row():
192
+ with gr.Column(scale=1):
193
+ # Title
194
+ gr.Markdown(title)
195
+
196
+ with gr.Column(scale=1):
197
+ # News
198
+ gr.Markdown(news)
199
+
200
+ with gr.Tab("Everything mode"):
201
+ # Images
202
+ with gr.Row(variant="panel"):
203
+ with gr.Column(scale=1):
204
+ cond_img_e.render()
205
+
206
+ with gr.Column(scale=1):
207
+ segm_img_e.render()
208
+
209
+ # Submit & Clear
210
+ with gr.Row():
211
+ with gr.Column():
212
+ input_size_slider.render()
213
+
214
+ with gr.Row():
215
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
216
+
217
+ with gr.Column():
218
+ segment_btn_e = gr.Button("Segment Everything", variant='primary')
219
+ clear_btn_e = gr.Button("Clear", variant="secondary")
220
+
221
+ gr.Markdown("Try some of the examples below ⬇️")
222
+ gr.Examples(examples=examples,
223
+ inputs=[cond_img_e],
224
+ outputs=segm_img_e,
225
+ fn=segment_everything,
226
+ cache_examples=True,
227
+ examples_per_page=4)
228
+
229
+ with gr.Column():
230
+ with gr.Accordion("Advanced options", open=False):
231
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
232
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
233
+ with gr.Row():
234
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
235
+ with gr.Column():
236
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
237
+
238
+ # Description
239
+ gr.Markdown(description_e)
240
+
241
+ with gr.Tab("Points mode"):
242
+ # Images
243
+ with gr.Row(variant="panel"):
244
+ with gr.Column(scale=1):
245
+ cond_img_p.render()
246
+
247
+ with gr.Column(scale=1):
248
+ segm_img_p.render()
249
+
250
+ # Submit & Clear
251
+ with gr.Row():
252
+ with gr.Column():
253
+ with gr.Row():
254
+ add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
255
+
256
+ with gr.Column():
257
+ segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
258
+ clear_btn_p = gr.Button("Clear points", variant='secondary')
259
+
260
+ gr.Markdown("Try some of the examples below ⬇️")
261
+ gr.Examples(examples=examples,
262
+ inputs=[cond_img_p],
263
+ outputs=segm_img_p,
264
+ fn=segment_with_points,
265
+ # cache_examples=True,
266
+ examples_per_page=4)
267
+
268
+ with gr.Column():
269
+ # Description
270
+ gr.Markdown(description_p)
271
+
272
+ cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
273
+
274
+ segment_btn_e.click(segment_everything,
275
+ inputs=[cond_img_e, input_size_slider, iou_threshold, conf_threshold, mor_check, contour_check, retina_check],
276
+ outputs=segm_img_e)
277
+
278
+ segment_btn_p.click(segment_with_points,
279
+ inputs=[cond_img_p],
280
+ outputs=[segm_img_p, cond_img_p])
281
+
282
+ def clear():
283
+ return None, None
284
+
285
+ clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
286
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
287
+
288
+ demo.queue()
289
+ demo.launch()
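As an aside for readers of this diff, the Points-mode wiring above can also be exercised outside Gradio. The sketch below loosely follows `segment_with_points` under the same assumptions; the choice of example image and the click coordinates are made up for illustration.

```python
# Points-mode inference without the Gradio UI, loosely following segment_with_points above.
import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO
from tools import fast_process, format_results, point_prompt

model = YOLO('checkpoints/FastSAM.pt')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

image = Image.open('assets/sa_862.jpg')           # illustrative choice of example image
w, h = image.size
scale = 1024 / max(w, h)                          # resize so the longer side is 1024
image = image.resize((int(w * scale), int(h * scale)))

results = model(image, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
formatted = format_results(results[0], 0)

# A single foreground click (label 1) in resized-image coordinates (made-up point).
mask, _ = point_prompt(formatted, [[200, 300]], [1], image.height, image.width)

overlay = fast_process(annotations=np.array([mask]), image=image, device=device,
                       scale=1, better_quality=False, mask_random_color=True,
                       bbox=None, use_retina=True, withContours=True)
overlay.save('points_mode.png')
```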
app_debug.py ADDED
@@ -0,0 +1,287 @@
1
+ from ultralytics import YOLO
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import gradio as gr
5
+ import cv2
6
+ import torch
7
+ from PIL import Image
8
+
9
+ # Load the pre-trained model
10
+ model = YOLO('checkpoints/FastSAM.pt')
11
+
12
+ # Description
13
+ title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
14
+
15
+ description = """This is a demo of the GitHub project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
16
+
17
+ 🎯 Upload an image and segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
18
+
19
+ ⌛️ It takes about 4 seconds to generate segmentation results. The queue's concurrency_count is 1, so please wait a moment when the demo is busy.
20
+
21
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
22
+
23
+ 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
24
+
25
+ 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with a GPU grant.
26
+
27
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
28
+
29
+ """
30
+
31
+ examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"],
32
+ ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
33
+ ["assets/sa_561.jpg"], ["assets/sa_192.jpg"],
34
+ ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
35
+
36
+ default_example = examples[0]
37
+
38
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
39
+
40
+ def fast_process(annotations, image, high_quality, device, scale):
41
+ if isinstance(annotations[0],dict):
42
+ annotations = [annotation['segmentation'] for annotation in annotations]
43
+
44
+ original_h = image.height
45
+ original_w = image.width
46
+ if high_quality == True:
47
+ if isinstance(annotations[0],torch.Tensor):
48
+ annotations = np.array(annotations.cpu())
49
+ for i, mask in enumerate(annotations):
50
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
51
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
52
+ if device == 'cpu':
53
+ annotations = np.array(annotations)
54
+ inner_mask = fast_show_mask(annotations,
55
+ plt.gca(),
56
+ bbox=None,
57
+ points=None,
58
+ pointlabel=None,
59
+ retinamask=True,
60
+ target_height=original_h,
61
+ target_width=original_w)
62
+ else:
63
+ if isinstance(annotations[0],np.ndarray):
64
+ annotations = torch.from_numpy(annotations)
65
+ inner_mask = fast_show_mask_gpu(annotations,
66
+ plt.gca(),
67
+ bbox=None,
68
+ points=None,
69
+ pointlabel=None)
70
+ if isinstance(annotations, torch.Tensor):
71
+ annotations = annotations.cpu().numpy()
72
+
73
+ if high_quality:
74
+ contour_all = []
75
+ temp = np.zeros((original_h, original_w,1))
76
+ for i, mask in enumerate(annotations):
77
+ if type(mask) == dict:
78
+ mask = mask['segmentation']
79
+ annotation = mask.astype(np.uint8)
80
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
81
+ for contour in contours:
82
+ contour_all.append(contour)
83
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
84
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
85
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
86
+ image = image.convert('RGBA')
87
+
88
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
89
+ image.paste(overlay_inner, (0, 0), overlay_inner)
90
+
91
+ if high_quality:
92
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
93
+ image.paste(overlay_contour, (0, 0), overlay_contour)
94
+
95
+ return image
96
+
97
+ # CPU post process
98
+ def fast_show_mask(annotation, ax, bbox=None,
99
+ points=None, pointlabel=None,
100
+ retinamask=True, target_height=960,
101
+ target_width=960):
102
+ msak_sum = annotation.shape[0]
103
+ height = annotation.shape[1]
104
+ weight = annotation.shape[2]
105
+ # sort the annotations by area
106
+ areas = np.sum(annotation, axis=(1, 2))
107
+ sorted_indices = np.argsort(areas)[::1]
108
+ annotation = annotation[sorted_indices]
109
+
110
+ index = (annotation != 0).argmax(axis=0)
111
+ color = np.random.random((msak_sum,1,1,3))
112
+ transparency = np.ones((msak_sum,1,1,1)) * 0.6
113
+ visual = np.concatenate([color,transparency],axis=-1)
114
+ mask_image = np.expand_dims(annotation,-1) * visual
115
+
116
+ mask = np.zeros((height,weight,4))
117
+
118
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
119
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
120
+ # update the displayed mask values with vectorized indexing
121
+ mask[h_indices, w_indices, :] = mask_image[indices]
122
+ if bbox is not None:
123
+ x1, y1, x2, y2 = bbox
124
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
125
+ # draw point
126
+ if points is not None:
127
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
128
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
129
+
130
+ if retinamask==False:
131
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
132
+
133
+ return mask
134
+
135
+
136
+ def fast_show_mask_gpu(annotation, ax,
137
+ bbox=None, points=None,
138
+ pointlabel=None):
139
+ msak_sum = annotation.shape[0]
140
+ height = annotation.shape[1]
141
+ weight = annotation.shape[2]
142
+ areas = torch.sum(annotation, dim=(1, 2))
143
+ sorted_indices = torch.argsort(areas, descending=False)
144
+ annotation = annotation[sorted_indices]
145
+ # index of the first non-zero mask at each pixel
146
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
147
+ color = torch.rand((msak_sum,1,1,3)).to(annotation.device)
148
+ transparency = torch.ones((msak_sum,1,1,1)).to(annotation.device) * 0.6
149
+ visual = torch.cat([color,transparency],dim=-1)
150
+ mask_image = torch.unsqueeze(annotation,-1) * visual
151
+ # gather by index: for each pixel, index selects which mask in the batch to take from mask_image
152
+ mask = torch.zeros((height,weight,4)).to(annotation.device)
153
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
154
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
155
+ # update the displayed mask values with vectorized indexing
156
+ mask[h_indices, w_indices, :] = mask_image[indices]
157
+ mask_cpu = mask.cpu().numpy()
158
+ if bbox is not None:
159
+ x1, y1, x2, y2 = bbox
160
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
161
+ # draw point
162
+ if points is not None:
163
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
164
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
165
+ return mask_cpu
166
+
167
+
168
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
169
+
170
+ def segment_image(input, evt: gr.SelectData=None, input_size=1024, high_visual_quality=True, iou_threshold=0.7, conf_threshold=0.25):
171
+ point = (evt.index[0], evt.index[1]) if evt is not None else None  # evt is None when there is no click event
172
+ input_size = int(input_size)  # make sure imgsz is an integer
173
+
174
+ # Thanks for the suggestion by hysts in HuggingFace.
175
+ w, h = input.size
176
+ scale = input_size / max(w, h)
177
+ new_w = int(w * scale)
178
+ new_h = int(h * scale)
179
+ input = input.resize((new_w, new_h))
180
+
181
+ results = model(input, device=device, retina_masks=True, iou=iou_threshold, conf=conf_threshold, imgsz=input_size)
182
+ fig = fast_process(annotations=results[0].masks.data,
183
+ image=input, high_quality=high_visual_quality,
184
+ device=device, scale=(1024 // input_size),
185
+ )  # this fast_process variant takes no 'points' argument
186
+ return fig
187
+
188
+ # input_size=1024
189
+ # high_quality_visual=True
190
+ # inp = 'assets/sa_192.jpg'
191
+ # input = Image.open(inp)
192
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
193
+ # input_size = int(input_size) # make sure imgsz is an integer
194
+ # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
195
+ # pil_image = fast_process(annotations=results[0].masks.data,
196
+ # image=input, high_quality=high_quality_visual, device=device)
197
+
198
+ cond_img = gr.Image(label="Input", value=default_example[0], type='pil')
199
+
200
+ segm_img = gr.Image(label="Segmented Image", interactive=False, type='pil')
201
+
202
+ input_size_slider = gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='Input_size (Our model was trained on a size of 1024)')
203
+
204
+ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
205
+ with gr.Row():
206
+ # Title
207
+ gr.Markdown(title)
208
+ # # # Description
209
+ # # gr.Markdown(description)
210
+
211
+ # Images
212
+ with gr.Row(variant="panel"):
213
+ with gr.Column(scale=1):
214
+ cond_img.render()
215
+
216
+ with gr.Column(scale=1):
217
+ segm_img.render()
218
+
219
+ # Submit & Clear
220
+ with gr.Row():
221
+ with gr.Column():
222
+ input_size_slider.render()
223
+
224
+ with gr.Row():
225
+ vis_check = gr.Checkbox(value=True, label='high_visual_quality')
226
+
227
+ with gr.Column():
228
+ segment_btn = gr.Button("Segment Anything", variant='primary')
229
+
230
+ # with gr.Column():
231
+ # clear_btn = gr.Button("Clear", variant="primary")
232
+
233
+ gr.Markdown("Try some of the examples below ⬇️")
234
+ gr.Examples(examples=examples,
235
+ inputs=[cond_img],
236
+ outputs=segm_img,
237
+ fn=segment_image,
238
+ cache_examples=True,
239
+ examples_per_page=4)
240
+ # gr.Markdown("Try some of the examples below ⬇️")
241
+ # gr.Examples(examples=examples,
242
+ # inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
243
+ # outputs=output,
244
+ # fn=segment_image,
245
+ # examples_per_page=4)
246
+
247
+ with gr.Column():
248
+ with gr.Accordion("Advanced options", open=False):
249
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou_threshold')
250
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf_threshold')
251
+
252
+ # Description
253
+ gr.Markdown(description)
254
+
255
+ cond_img.select(segment_image, [cond_img], segm_img)
256
+
257
+ segment_btn.click(segment_image,
258
+ inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
259
+ outputs=segm_img)
260
+
261
+ # def clear():
262
+ # return None, None
263
+
264
+ # clear_btn.click(fn=clear, inputs=None, outputs=None)
265
+
266
+ demo.queue()
267
+ demo.launch()
268
+
269
+ # app_interface = gr.Interface(fn=predict,
270
+ # inputs=[gr.Image(type='pil'),
271
+ # gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
272
+ # gr.components.Checkbox(value=True, label='high_visual_quality')],
273
+ # # outputs=['plot'],
274
+ # outputs=gr.Image(type='pil'),
275
+ # # examples=[["assets/sa_8776.jpg"]],
276
+ # # # ["assets/sa_1309.jpg", 1024]],
277
+ # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
278
+ # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
279
+ # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
280
+ # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
281
+ # cache_examples=True,
282
+ # title="Fast Segment Anything (Everything mode)"
283
+ # )
284
+
285
+
286
+ # app_interface.queue(concurrency_count=1, max_size=20)
287
+ # app_interface.launch()
assets/sa_10039.jpg ADDED

Git LFS Details

  • SHA256: 4a9735583a997fa08e5eb36b3ba8bf17a31771bb2aea71e6d51ab9824c1d141e
  • Pointer size: 131 Bytes
  • Size of remote file: 391 kB
assets/sa_11025.jpg ADDED

Git LFS Details

  • SHA256: b7edd63aa5121414bc29a760770606d09387561ff990c89f9b82c35803bd20aa
  • Pointer size: 131 Bytes
  • Size of remote file: 988 kB
assets/sa_1309.jpg ADDED

Git LFS Details

  • SHA256: b1012cbfd3ffe4ee0da940dc45961fbd1ce7546bea566f650514ec56d72b0460
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
assets/sa_192.jpg ADDED

Git LFS Details

  • SHA256: dcec4fce91382cbfeb2711fff3caeae183c23cb6d8a6c9e2ca0cd2e8eac39512
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
assets/sa_414.jpg ADDED

Git LFS Details

  • SHA256: 69dbead40b43e54d3bb80fb372c2e241b0f3ff2159d32525433a75153e067c65
  • Pointer size: 132 Bytes
  • Size of remote file: 2.23 MB
assets/sa_561.jpg ADDED

Git LFS Details

  • SHA256: 837d725885e427534623dcc7d82ea846fffea046877c94e2e9c5b027d593796b
  • Pointer size: 131 Bytes
  • Size of remote file: 822 kB
assets/sa_862.jpg ADDED

Git LFS Details

  • SHA256: 06efc970f0d95faa6e8c69ee73f2032627569dde1c28bc783faebdaefa5eb2a8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.56 MB
assets/sa_8776.jpg ADDED

Git LFS Details

  • SHA256: 7d71aea32d9f14122378a0707a4243de968d87b292a20a905351b5eacd924212
  • Pointer size: 131 Bytes
  • Size of remote file: 471 kB
checkpoints/FastSAM.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
+ size 144943063
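The checkpoint is committed only as the Git LFS pointer above (about 145 MB of actual weights). If you want the file outside the Space runtime, one possible way is `hf_hub_download` from `huggingface_hub`; the repo id below is taken from the `duplicated_from` field in README.md, and the rest of this snippet is a sketch rather than part of the repository.

```python
# One possible way to fetch the LFS-tracked checkpoint outside this Space.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="An-619/FastSAM",            # source Space, from `duplicated_from` in README.md
    filename="checkpoints/FastSAM.pt",
    repo_type="space",                   # the file lives in a Space repo, not a model repo
)
print(ckpt_path)                         # local cache path to the ~145 MB weights
```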
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ # Base-----------------------------------
+ matplotlib==3.2.2
+ numpy
+ opencv-python
+ # Pillow>=7.1.2
+ # PyYAML>=5.3.1
+ # requests>=2.23.0
+ # scipy>=1.4.1
+ # torch
+ # torchvision
+ # tqdm>=4.64.0
+
+ # pandas>=1.1.4
+ # seaborn>=0.11.0
+
+ # Ultralytics-----------------------------------
+ ultralytics==8.0.121
+
tools.py ADDED
@@ -0,0 +1,361 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import torch
6
+ # import clip
7
+
8
+
9
+ def convert_box_xywh_to_xyxy(box):
10
+ x1 = box[0]
11
+ y1 = box[1]
12
+ x2 = box[0] + box[2]
13
+ y2 = box[1] + box[3]
14
+ return [x1, y1, x2, y2]
15
+
16
+
17
+ def segment_image(image, bbox):
18
+ image_array = np.array(image)
19
+ segmented_image_array = np.zeros_like(image_array)
20
+ x1, y1, x2, y2 = bbox
21
+ segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
22
+ segmented_image = Image.fromarray(segmented_image_array)
23
+ black_image = Image.new("RGB", image.size, (255, 255, 255))
24
+ # transparency_mask = np.zeros_like((), dtype=np.uint8)
25
+ transparency_mask = np.zeros(
26
+ (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
27
+ )
28
+ transparency_mask[y1:y2, x1:x2] = 255
29
+ transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
30
+ black_image.paste(segmented_image, mask=transparency_mask_image)
31
+ return black_image
32
+
33
+
34
+ def format_results(result, filter=0):
35
+ annotations = []
36
+ n = len(result.masks.data)
37
+ for i in range(n):
38
+ annotation = {}
39
+ mask = result.masks.data[i] == 1.0
40
+
41
+ if torch.sum(mask) < filter:
42
+ continue
43
+ annotation["id"] = i
44
+ annotation["segmentation"] = mask.cpu().numpy()
45
+ annotation["bbox"] = result.boxes.data[i]
46
+ annotation["score"] = result.boxes.conf[i]
47
+ annotation["area"] = annotation["segmentation"].sum()
48
+ annotations.append(annotation)
49
+ return annotations
50
+
51
+
52
+ def filter_masks(annotations):  # filter overlapping masks
53
+ annotations.sort(key=lambda x: x["area"], reverse=True)
54
+ to_remove = set()
55
+ for i in range(0, len(annotations)):
56
+ a = annotations[i]
57
+ for j in range(i + 1, len(annotations)):
58
+ b = annotations[j]
59
+ if i != j and j not in to_remove:
60
+ # remove the smaller mask b when it is mostly (>80%) covered by the larger mask a
61
+ if b["area"] < a["area"]:
62
+ if (a["segmentation"] & b["segmentation"]).sum() / b[
63
+ "segmentation"
64
+ ].sum() > 0.8:
65
+ to_remove.add(j)
66
+
67
+ return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
68
+
69
+
70
+ def get_bbox_from_mask(mask):
71
+ mask = mask.astype(np.uint8)
72
+ contours, hierarchy = cv2.findContours(
73
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
74
+ )
75
+ x1, y1, w, h = cv2.boundingRect(contours[0])
76
+ x2, y2 = x1 + w, y1 + h
77
+ if len(contours) > 1:
78
+ for b in contours:
79
+ x_t, y_t, w_t, h_t = cv2.boundingRect(b)
80
+ # merge the multiple bounding boxes into one
81
+ x1 = min(x1, x_t)
82
+ y1 = min(y1, y_t)
83
+ x2 = max(x2, x_t + w_t)
84
+ y2 = max(y2, y_t + h_t)
85
+ h = y2 - y1
86
+ w = x2 - x1
87
+ return [x1, y1, x2, y2]
88
+
89
+ def fast_process(
90
+ annotations,
91
+ image,
92
+ device,
93
+ scale,
94
+ better_quality=False,
95
+ mask_random_color=True,
96
+ bbox=None,
97
+ use_retina=True,
98
+ withContours=True,
99
+ ):
100
+ if isinstance(annotations[0], dict):
101
+ annotations = [annotation['segmentation'] for annotation in annotations]
102
+
103
+ original_h = image.height
104
+ original_w = image.width
105
+ if better_quality:
106
+ if isinstance(annotations[0], torch.Tensor):
107
+ annotations = np.array(annotations.cpu())
108
+ for i, mask in enumerate(annotations):
109
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
110
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
111
+ if device == 'cpu':
112
+ annotations = np.array(annotations)
113
+ inner_mask = fast_show_mask(
114
+ annotations,
115
+ plt.gca(),
116
+ random_color=mask_random_color,
117
+ bbox=bbox,
118
+ retinamask=use_retina,
119
+ target_height=original_h,
120
+ target_width=original_w,
121
+ )
122
+ else:
123
+ if isinstance(annotations[0], np.ndarray):
124
+ annotations = torch.from_numpy(annotations)
125
+ inner_mask = fast_show_mask_gpu(
126
+ annotations,
127
+ plt.gca(),
128
+ random_color=mask_random_color,
129
+ bbox=bbox,
130
+ retinamask=use_retina,
131
+ target_height=original_h,
132
+ target_width=original_w,
133
+ )
134
+ if isinstance(annotations, torch.Tensor):
135
+ annotations = annotations.cpu().numpy()
136
+
137
+ if withContours:
138
+ contour_all = []
139
+ temp = np.zeros((original_h, original_w, 1))
140
+ for i, mask in enumerate(annotations):
141
+ if type(mask) == dict:
142
+ mask = mask['segmentation']
143
+ annotation = mask.astype(np.uint8)
144
+ if use_retina == False:
145
+ annotation = cv2.resize(
146
+ annotation,
147
+ (original_w, original_h),
148
+ interpolation=cv2.INTER_NEAREST,
149
+ )
150
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
151
+ for contour in contours:
152
+ contour_all.append(contour)
153
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
154
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
155
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
156
+
157
+ image = image.convert('RGBA')
158
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
159
+ image.paste(overlay_inner, (0, 0), overlay_inner)
160
+
161
+ if withContours:
162
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
163
+ image.paste(overlay_contour, (0, 0), overlay_contour)
164
+
165
+ return image
166
+
167
+
168
+ # CPU post process
169
+ def fast_show_mask(
170
+ annotation,
171
+ ax,
172
+ random_color=False,
173
+ bbox=None,
174
+ retinamask=True,
175
+ target_height=960,
176
+ target_width=960,
177
+ ):
178
+ mask_sum = annotation.shape[0]
179
+ height = annotation.shape[1]
180
+ weight = annotation.shape[2]
181
+ # sort the annotations by area
182
+ areas = np.sum(annotation, axis=(1, 2))
183
+ sorted_indices = np.argsort(areas)[::1]
184
+ annotation = annotation[sorted_indices]
185
+
186
+ index = (annotation != 0).argmax(axis=0)
187
+ if random_color == True:
188
+ color = np.random.random((mask_sum, 1, 1, 3))
189
+ else:
190
+ color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
191
+ transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
192
+ visual = np.concatenate([color, transparency], axis=-1)
193
+ mask_image = np.expand_dims(annotation, -1) * visual
194
+
195
+ mask = np.zeros((height, weight, 4))
196
+
197
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
198
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
199
+
200
+ mask[h_indices, w_indices, :] = mask_image[indices]
201
+ if bbox is not None:
202
+ x1, y1, x2, y2 = bbox
203
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
204
+
205
+ if retinamask == False:
206
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
207
+
208
+ return mask
209
+
210
+
211
+ def fast_show_mask_gpu(
212
+ annotation,
213
+ ax,
214
+ random_color=False,
215
+ bbox=None,
216
+ retinamask=True,
217
+ target_height=960,
218
+ target_width=960,
219
+ ):
220
+ device = annotation.device
221
+ mask_sum = annotation.shape[0]
222
+ height = annotation.shape[1]
223
+ weight = annotation.shape[2]
224
+ areas = torch.sum(annotation, dim=(1, 2))
225
+ sorted_indices = torch.argsort(areas, descending=False)
226
+ annotation = annotation[sorted_indices]
227
+ # index of the first non-zero mask at each pixel
228
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
229
+ if random_color == True:
230
+ color = torch.rand((mask_sum, 1, 1, 3)).to(device)
231
+ else:
232
+ color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
233
+ [30 / 255, 144 / 255, 255 / 255]
234
+ ).to(device)
235
+ transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
236
+ visual = torch.cat([color, transparency], dim=-1)
237
+ mask_image = torch.unsqueeze(annotation, -1) * visual
238
+ # gather by index: for each pixel, index selects which mask in the batch to take from mask_image
239
+ mask = torch.zeros((height, weight, 4)).to(device)
240
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
241
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
242
+ # update the displayed mask values with vectorized indexing
243
+ mask[h_indices, w_indices, :] = mask_image[indices]
244
+ mask_cpu = mask.cpu().numpy()
245
+ if bbox is not None:
246
+ x1, y1, x2, y2 = bbox
247
+ ax.add_patch(
248
+ plt.Rectangle(
249
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
250
+ )
251
+ )
252
+ if retinamask == False:
253
+ mask_cpu = cv2.resize(
254
+ mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
255
+ )
256
+ return mask_cpu
257
+
258
+
259
+ # # clip
260
+ # @torch.no_grad()
261
+ # def retriev(
262
+ # model, preprocess, elements, search_text: str, device
263
+ # ) -> int:
264
+ # preprocessed_images = [preprocess(image).to(device) for image in elements]
265
+ # tokenized_text = clip.tokenize([search_text]).to(device)
266
+ # stacked_images = torch.stack(preprocessed_images)
267
+ # image_features = model.encode_image(stacked_images)
268
+ # text_features = model.encode_text(tokenized_text)
269
+ # image_features /= image_features.norm(dim=-1, keepdim=True)
270
+ # text_features /= text_features.norm(dim=-1, keepdim=True)
271
+ # probs = 100.0 * image_features @ text_features.T
272
+ # return probs[:, 0].softmax(dim=0)
273
+
274
+
275
+ def crop_image(annotations, image_path):
276
+ image = Image.open(image_path)
277
+ ori_w, ori_h = image.size
278
+ mask_h, mask_w = annotations[0]["segmentation"].shape
279
+ if ori_w != mask_w or ori_h != mask_h:
280
+ image = image.resize((mask_w, mask_h))
281
+ cropped_boxes = []
282
+ cropped_images = []
283
+ not_crop = []
284
+ filter_id = []
285
+ # annotations, _ = filter_masks(annotations)
286
+ # filter_id = list(_)
287
+ for _, mask in enumerate(annotations):
288
+ if np.sum(mask["segmentation"]) <= 100:
289
+ filter_id.append(_)
290
+ continue
291
+ bbox = get_bbox_from_mask(mask["segmentation"])  # bbox of the mask
292
+ cropped_boxes.append(segment_image(image, bbox))  # save the cropped image
293
+ # cropped_boxes.append(segment_image(image,mask["segmentation"]))
294
+ cropped_images.append(bbox)  # save the bbox of the cropped image
295
+
296
+ return cropped_boxes, cropped_images, not_crop, filter_id, annotations
297
+
298
+
299
+ def box_prompt(masks, bbox, target_height, target_width):
300
+ h = masks.shape[1]
301
+ w = masks.shape[2]
302
+ if h != target_height or w != target_width:
303
+ bbox = [
304
+ int(bbox[0] * w / target_width),
305
+ int(bbox[1] * h / target_height),
306
+ int(bbox[2] * w / target_width),
307
+ int(bbox[3] * h / target_height),
308
+ ]
309
+ bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
310
+ bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
311
+ bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
312
+ bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
313
+
314
+ # IoUs = torch.zeros(len(masks), dtype=torch.float32)
315
+ bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
316
+
317
+ masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
318
+ orig_masks_area = torch.sum(masks, dim=(1, 2))
319
+
320
+ union = bbox_area + orig_masks_area - masks_area
321
+ IoUs = masks_area / union
322
+ max_iou_index = torch.argmax(IoUs)
323
+
324
+ return masks[max_iou_index].cpu().numpy(), max_iou_index
325
+
326
+
327
+ def point_prompt(masks, points, pointlabel, target_height, target_width):  # numpy post-processing
328
+ h = masks[0]["segmentation"].shape[0]
329
+ w = masks[0]["segmentation"].shape[1]
330
+ if h != target_height or w != target_width:
331
+ points = [
332
+ [int(point[0] * w / target_width), int(point[1] * h / target_height)]
333
+ for point in points
334
+ ]
335
+ onemask = np.zeros((h, w))
336
+ for i, annotation in enumerate(masks):
337
+ if type(annotation) == dict:
338
+ mask = annotation["segmentation"]
339
+ else:
340
+ mask = annotation
341
+ for i, point in enumerate(points):
342
+ if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
343
+ onemask += mask
344
+ if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
345
+ onemask -= mask
346
+ onemask = onemask >= 1
347
+ return onemask, 0
348
+
349
+
350
+ # def text_prompt(annotations, args):
351
+ # cropped_boxes, cropped_images, not_crop, filter_id, annotaions = crop_image(
352
+ # annotations, args.img_path
353
+ # )
354
+ # clip_model, preprocess = clip.load("ViT-B/32", device=args.device)
355
+ # scores = retriev(
356
+ # clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device
357
+ # )
358
+ # max_idx = scores.argsort()
359
+ # max_idx = max_idx[-1]
360
+ # max_idx += sum(np.array(filter_id) <= int(max_idx))
361
+ # return annotaions[max_idx]["segmentation"], max_idx