T.Masuda committed
Commit d65ec94 · 1 Parent(s): b7e0b23

clip-image

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoint/sam_vit_h_4b8939.pth filter=lfs diff=lfs merge=lfs -text
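The new rule keeps the 2.4 GB SAM checkpoint under Git LFS rather than in the regular object store; for reference, running git lfs track "checkpoint/sam_vit_h_4b8939.pth" from the repository root appends exactly such a line to .gitattributes.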
app.py ADDED
@@ -0,0 +1,121 @@
+import gradio as gr
+import numpy as np
+from meta_segment_anything import SegmentAnything
+from PIL import Image, ImageDraw
+
+def check_location(image, enable1, left1, top1, enable2, left2, top2, enable3, left3, top3):
+    if image is None:
+        yield None
+        return
+    if not enable1 and not enable2 and not enable3:
+        yield None
+        return
+
+    points = []
+    if enable1:
+        points.append([left1, top1])
+    if enable2:
+        points.append([left2, top2])
+    if enable3:
+        points.append([left3, top3])
+    for point in points:
+        left, top = point
+        draw = ImageDraw.Draw(image)
+        draw.ellipse([(left - 2, top - 2), (left + 3, top + 3)], fill=(255, 0, 0))
+    yield image
+
+def process_image(image, enable1, left1, top1, enable2, left2, top2, enable3, left3, top3):
+    if image is None:
+        yield None
+        return
+    if not enable1 and not enable2 and not enable3:
+        yield None
+        return
+
+    predictor = SegmentAnything()
+    points = []
+    if enable1:
+        points.append([left1, top1])
+    if enable2:
+        points.append([left2, top2])
+    if enable3:
+        points.append([left3, top3])
+    newImage = Image.new('RGBA', image.size)
+    for point in points:
+        point_coords = np.array([[0, 0], point])
+        point_labels = np.array([0, 1])
+        masks, _, _ = predictor.predict(image, point_coords, point_labels)
+        index = 0
+        for mask in masks:
+            maskimage = SegmentAnything.makeMaskImage(mask.T, (0xff, 0xff, 0xff, 0xff))
+            index += 1
+        maskNewImage = SegmentAnything.makeNewImage(image, maskimage)
+        newImage.paste(maskNewImage, (0, 0), maskNewImage)
+        yield newImage
+
+def tab_select(evt: gr.SelectData, state):
+    if evt.target.label == 'point2':
+        state['active'] = 1
+    elif evt.target.label == 'point3':
+        state['active'] = 2
+    else:
+        state['active'] = 0
+    return state
+
+def image_select(evt: gr.SelectData, state, enable1, left1, top1, enable2, left2, top2, enable3, left3, top3):
+    if state['active'] == 2:
+        return [enable1, left1, top1, enable2, left2, top2, True, evt.index[0], evt.index[1]]
+    elif state['active'] == 1:
+        return [enable1, left1, top1, True, evt.index[0], evt.index[1], enable3, left3, top3]
+    return [True, evt.index[0], evt.index[1], enable2, left2, top2, enable3, left3, top3]
+
+with gr.Blocks(title='clip-image') as app:
+    state = gr.State({ 'active': 0 })
+
+    gr.Markdown('''
+    # Clip Image
+    clip an image from given points
+    ''')
+    with gr.Row():
+        with gr.Column():
+            image = gr.Image(type='pil')
+            gr.Markdown('click on the image to position')
+            with gr.Tab("point1") as tab1:
+                enable1 = gr.Checkbox(label='enable', value=True)
+                left1 = gr.Slider(maximum=4000, step=1, label='left')
+                top1 = gr.Slider(maximum=4000, step=1, label='top')
+            with gr.Tab("point2") as tab2:
+                enable2 = gr.Checkbox(label='enable')
+                left2 = gr.Slider(maximum=4000, step=1, label='left')
+                top2 = gr.Slider(maximum=4000, step=1, label='top')
+            with gr.Tab("point3") as tab3:
+                enable3 = gr.Checkbox(label='enable')
+                left3 = gr.Slider(maximum=4000, step=1, label='left')
+                top3 = gr.Slider(maximum=4000, step=1, label='top')
+            btnloc = gr.Button(value='check location')
+            with gr.Row():
+                with gr.Column(min_width=160):
+                    clearBtn = gr.ClearButton()
+                with gr.Column(min_width=160):
+                    btn = gr.Button(value='Submit')
+            inputs = [image, enable1, left1, top1, enable2, left2, top2, enable3, left3, top3]
+        with gr.Column():
+            outputs = [gr.Image(label='segmentation', type='pil')]
+    tab1.select(tab_select, inputs=state, outputs=state)
+    tab2.select(tab_select, inputs=state, outputs=state)
+    tab3.select(tab_select, inputs=state, outputs=state)
+    image.select(image_select, inputs=[state, enable1, left1, top1, enable2, left2, top2, enable3, left3, top3], outputs=[enable1, left1, top1, enable2, left2, top2, enable3, left3, top3])
+    btnloc.click(check_location, inputs=inputs, outputs=outputs)
+    clearBtn.add(inputs + outputs)
+    btn.click(process_image, inputs=inputs, outputs=outputs)
+
+    gr.Examples(
+        [['examples/example1.jpg', True, 200, 250, True, 340, 250, False, 0, 0], ['examples/example2.jpg', True, 256, 256, False, 0, 0, False, 0, 0]],
+        inputs,
+        outputs,
+        process_image,
+        #cache_examples=True,
+    )
+
+app.queue(concurrency_count=5)
+app.launch()
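A note on the prompt process_image builds: point_coords pairs a fixed point at the origin, labelled 0 (background), with the clicked point, labelled 1 (foreground), which steers SAM away from masks that swallow the whole frame. Note also that the inner loop over masks overwrites maskimage on every iteration, so only the last of SAM's candidate masks is composited, and index is counted but never read. The predictor can be exercised without the Gradio UI; a minimal sketch, assuming the LFS checkpoint and example images have been pulled (the click coordinates come from the first Examples row):

import numpy as np
from PIL import Image
from meta_segment_anything import SegmentAnything

predictor = SegmentAnything()
image = Image.open('examples/example1.jpg').convert('RGB')

# one background hint at the origin plus one foreground click,
# mirroring what process_image builds for each enabled point
point_coords = np.array([[0, 0], [200, 250]])
point_labels = np.array([0, 1])

# SamPredictor.predict defaults to multimask_output=True, so three
# candidate masks come back along with predicted quality scores
masks, scores, _ = predictor.predict(image, point_coords, point_labels)
print(masks.shape, scores)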
checkpoint/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
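Only this Git LFS pointer is committed; the 2.4 GB weights file it stands for, identified by the sha256 oid above, lives in LFS storage and is fetched with git lfs pull after cloning.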
examples/example1.jpg ADDED
examples/example2.jpg ADDED
meta_segment_anything.py ADDED
@@ -0,0 +1,44 @@
+from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+import torch
+import numpy as np
+from PIL import Image
+
+class SegmentAnything:
+    def __init__(self):
+        sam_checkpoint = 'checkpoint/sam_vit_h_4b8939.pth'
+        model_type = 'vit_h'
+        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+        if torch.cuda.is_available():
+            sam.to(device='cuda')
+        self.sam = sam
+
+    def predict(self, image, point_coords, point_labels, box=None):
+        predictor = SamPredictor(self.sam)
+        predictor.set_image(np.array(image, dtype=np.uint8))
+        return predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box)
+
+    def generate(self, image):
+        mask_generator = SamAutomaticMaskGenerator(self.sam)
+        return mask_generator.generate(np.array(image, dtype=np.uint8))
+
+    @staticmethod
+    def makeMaskImage(mask, color):
+        image = Image.new('RGBA', mask.shape)
+        width, height = image.size
+        for x in range(width):
+            for y in range(height):
+                if mask[x, y]:
+                    image.putpixel((x, y), color)
+        return image
+
+    @staticmethod
+    def makeNewImage(image, maskImage):
+        newImage = Image.new('RGBA', image.size)
+        timage = maskImage.copy()
+        width, height = timage.size
+        for x in range(width):
+            for y in range(height):
+                _, _, _, a = timage.getpixel((x, y))
+                timage.putpixel((x, y), (0, 0, 0, 255) if a > 0 else (0, 0, 0, 0))
+        newImage.paste(image, (0, 0), timage)
+        return newImage
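makeMaskImage and makeNewImage visit every pixel with getpixel/putpixel, which is slow pure-Python work on a multi-megapixel photo. A vectorized sketch of the same two helpers, offered as an assumption about an equivalent implementation rather than part of this commit:

import numpy as np
from PIL import Image

def make_mask_image(mask, color):
    # mask arrives transposed as (width, height), matching process_image's mask.T
    w, h = mask.shape
    rgba = np.zeros((h, w, 4), dtype=np.uint8)
    rgba[mask.T] = color  # paint every masked pixel in one assignment
    return Image.fromarray(rgba, 'RGBA')

def make_new_image(image, mask_image):
    # binarize the mask's alpha channel and use it directly as the paste mask
    alpha = mask_image.getchannel('A').point(lambda a: 255 if a > 0 else 0)
    new_image = Image.new('RGBA', image.size)
    new_image.paste(image, (0, 0), alpha)
    return new_image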
requirements.txt ADDED
@@ -0,0 +1,5 @@
+gradio
+torch
+torchvision
+torchaudio
+git+https://github.com/facebookresearch/segment-anything.git
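Setup is the usual pip install -r requirements.txt; segment-anything is installed straight from the GitHub repository rather than from PyPI, and the checkpoint under checkpoint/ additionally needs git lfs pull after cloning.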