sentencebird committed on
Commit
e82d36a
·
1 Parent(s): 66115cd

add: torchvisionのsegmentationで背景マスキング

Browse files
Files changed (6) hide show
  1. app.py +28 -49
  2. cv_funcs.py +67 -0
  3. favicon.jpeg +0 -0
  4. icon.jpeg +0 -0
  5. requirements.txt +3 -1
  6. torchvision_funcs.py +41 -0
app.py CHANGED
@@ -6,83 +6,62 @@ import random
6
  import time
7
  import seaborn as sns
8
 
9
- def get_concat_h(im1, im2):
10
- dst = Image.new('RGB', (im1.width + im2.width, im1.height))
11
- dst.paste(im1, (0, 0))
12
- dst.paste(im2, (im1.width, 0))
13
- return dst
14
-
15
- def get_concat_v(im1, im2):
16
- dst = Image.new('RGB', (im1.width, im1.height + im2.height))
17
- dst.paste(im1, (0, 0))
18
- dst.paste(im2, (0, im1.height))
19
- return dst
20
-
21
- def hsv_to_rgb(h, s, v):
22
- bgr = cv2.cvtColor(np.array([[[h, s, v]]], dtype=np.uint8), cv2.COLOR_HSV2BGR)[0][0]
23
- return [bgr[2]/255, bgr[1]/255, bgr[0]/255]
24
 
25
  @st.cache
26
  def show_generated_image(image):
27
  st.image(image)
28
-
29
  @st.cache(suppress_st_warning=True)
30
- def randomize_palette_colors(n_rows, n_cols, palette1="Set1", palette2="Set2", seed=0):
31
  random.seed(seed)
32
- colors1 = sns.color_palette(palette1, n_rows*n_cols)
33
- colors2 = sns.color_palette(palette2, n_rows*n_cols)
34
- colors1, colors2 = random.sample(colors1, len(colors1)), random.sample(colors2, len(colors2))
35
- return colors1, colors2
36
 
37
  @st.cache(suppress_st_warning=True)
38
- def randomize_rgb_colors(n_rows, n_cols, seed=0):
39
- random.seed(seed)
40
- colors1 = [[random.random() for j in range(3)] for i in range(n_rows*n_cols)]
41
- colors2 = [[random.random() for j in range(3)] for i in range(n_rows*n_cols)]
42
- return colors1, colors2
43
-
44
- @st.cache(suppress_st_warning=True)
45
- def randomize_hsv_colors(n_rows, n_cols, s=255, v=255, seed=0):
46
- random.seed(seed)
47
- colors1 = [hsv_to_rgb(random.random()*180, s, v) for i in range(n_rows*n_cols)]
48
- colors2 = [hsv_to_rgb(random.random()*180, s, v) for i in range(n_rows*n_cols)]
49
- return colors1, colors2
50
 
51
  title = 'Andy Warhol like Image Generator'
52
- st.set_page_config(page_title=title, layout='centered')
53
  st.title(title)
54
  uploaded_file = st.file_uploader('Choose an image file')
55
  if uploaded_file is None: uploaded_file = './sample.jpg'
56
 
57
  if uploaded_file is not None:
58
  im = Image.open(uploaded_file)
59
- im.thumbnail((1000, 1000),resample=Image.BICUBIC) # resize
60
- st.image(im, caption='Original')
61
-
62
- im_gray = np.array(im.convert('L'))
 
 
 
 
 
63
  thresh, _img = cv2.threshold(im_gray, 0, 255, cv2.THRESH_OTSU)
64
 
65
  n_rows, n_cols = st.number_input('Rows', value=3), st.number_input('Columns', value=3)
66
-
67
- # s = st.slider('Saturation', value=125.0, min_value=0.0, max_value=255.0)
68
- # v = st.slider('Brightness', value=255.0, min_value=0.0, max_value=255.0)
69
- colors1, colors2 = randomize_palette_colors(n_rows, n_cols)
70
- thresh = st.slider('Threshold', value=thresh, min_value=0.0, max_value=255.0)
71
 
72
  if st.button('Shuffle colors'):
73
- colors1, colors2 = randomize_palette_colors(n_rows, n_cols, seed=time.time())
74
-
75
  if True or st.button('Generate'):
76
- im_bool = im_gray > thresh
77
-
78
  ims_generated = []
 
79
  for row in range(n_rows):
80
  for col in range(n_cols):
81
  i_color = n_cols * row + col
82
- rgb1, rgb2 = np.array(colors1[i_color])*np.array([255, 255, 255]).tolist(), np.array(colors2[i_color])*np.array([255, 255, 255]).tolist()
83
  ims_col = np.empty((*im_gray.shape, 3))
84
  for i in range(3): # RGB
85
- ims_col[:, :, i] = (im_gray > thresh) * rgb1[i] + (im_gray <= thresh) * rgb2[i]
 
86
  if col == 0:
87
  im_col_concat = Image.fromarray(ims_col.astype(np.uint8))
88
  else:
 
6
  import time
7
  import seaborn as sns
8
 
9
+ from cv_funcs import *
10
+ from torchvision_funcs import *
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  @st.cache
13
  def show_generated_image(image):
14
  st.image(image)
15
+
16
  @st.cache(suppress_st_warning=True)
17
def randomize_palette_colors(n_rows, n_cols, palettes=('Set1', 'Set3', 'Spectral'), seed=None, n_times=10):
    """Build one shuffled seaborn colour list per palette.

    Args:
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        palettes: Seaborn palette names; one colour list is produced per
            name. An immutable tuple is used as the default so the default
            cannot be mutated between calls.
        seed: Value passed to ``random.seed``. With ``None`` (the default)
            ``random`` seeds itself afresh on each call, rather than reusing
            a single ``time.time()`` value frozen when the function was
            defined.
        n_times: Multiplier applied to ``n_rows * n_cols`` for the number of
            colours requested from each palette.

    Returns:
        A list with one shuffled colour list per entry in ``palettes``.
    """
    # NOTE(review): the st.cache decorator on this function may hand back a
    # previous result for identical arguments -- pass an explicit, changing
    # seed when a fresh shuffle is required.
    random.seed(seed)
    n_colors = n_rows * n_cols * n_times
    colors = [sns.color_palette(palette, n_colors) for palette in palettes]
    # Shuffle in place with a plain loop; a comprehension used only for its
    # side effects builds and discards a list of ``None``.
    for palette_colors in colors:
        random.shuffle(palette_colors)
    return colors
 
22
 
23
  @st.cache(suppress_st_warning=True)
24
def remove_image_background(image):
    """Run DeepLabV3 background removal on ``image``.

    Thin wrapper around ``deeplabv3_remove_bg``; whatever that helper
    returns is passed through unchanged.
    """
    result = deeplabv3_remove_bg(image)
    return result
 
 
 
 
 
 
 
 
 
 
26
 
27
  title = 'Andy Warhol like Image Generator'
28
+ st.set_page_config(page_title=title, page_icon='favicon.jpeg', layout='centered')
29
  st.title(title)
30
  uploaded_file = st.file_uploader('Choose an image file')
31
  if uploaded_file is None: uploaded_file = './sample.jpg'
32
 
33
  if uploaded_file is not None:
34
  im = Image.open(uploaded_file)
35
+ im.thumbnail((1000, 1000),resample=Image.BICUBIC) # resize
36
+
37
+ is_masked = st.checkbox('With background masking? (3 colors)')
38
+ if is_masked:
39
+ im_masked, index_masked = remove_image_background(im)
40
+ st.image(im_masked, caption='Masked image')
41
+ else: st.image(im, caption='Original')
42
+
43
+ im_gray = np.array(im.convert('L'))
44
  thresh, _img = cv2.threshold(im_gray, 0, 255, cv2.THRESH_OTSU)
45
 
46
  n_rows, n_cols = st.number_input('Rows', value=3), st.number_input('Columns', value=3)
47
+
48
+ thresh = st.slider('Threshold', value=thresh, min_value=0.0, max_value=255.0)
49
+ colors = randomize_palette_colors(n_rows, n_cols, seed=0)
 
 
50
 
51
  if st.button('Shuffle colors'):
52
+ colors = randomize_palette_colors(n_rows, n_cols, seed=time.time())
53
+
54
  if True or st.button('Generate'):
 
 
55
  ims_generated = []
56
+
57
  for row in range(n_rows):
58
  for col in range(n_cols):
59
  i_color = n_cols * row + col
60
+ rgbs = [np.array(color[i_color])*np.array([255, 255, 255]).tolist() for color in colors]
61
  ims_col = np.empty((*im_gray.shape, 3))
62
  for i in range(3): # RGB
63
+ ims_col[:, :, i] = (im_gray <= thresh) * rgbs[0][i] + (im_gray > thresh) * rgbs[1][i]
64
+ if is_masked: ims_col[:, :, i][index_masked] = rgbs[2][i]
65
  if col == 0:
66
  im_col_concat = Image.fromarray(ims_col.astype(np.uint8))
67
  else:
cv_funcs.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
def get_concat_h(im1, im2):
    """Return a new RGB image with ``im2`` pasted to the right of ``im1``.

    The canvas is as wide as both images together and as tall as ``im1``.
    """
    left_width = im1.width
    canvas = Image.new('RGB', (left_width + im2.width, im1.height))
    for tile, x_offset in ((im1, 0), (im2, left_width)):
        canvas.paste(tile, (x_offset, 0))
    return canvas
10
+
11
def get_concat_v(im1, im2):
    """Return a new RGB image with ``im2`` pasted below ``im1``.

    The canvas is as wide as ``im1`` and as tall as both images together.
    """
    top_height = im1.height
    canvas = Image.new('RGB', (im1.width, top_height + im2.height))
    for tile, y_offset in ((im1, 0), (im2, top_height)):
        canvas.paste(tile, (0, y_offset))
    return canvas
16
+
17
def hsv_to_rgb(h, s, v):
    """Convert one HSV colour to an ``[r, g, b]`` list scaled by 1/255.

    The triple is cast to ``uint8`` and converted with OpenCV's
    ``COLOR_HSV2BGR``; the resulting channels are then reordered to RGB.
    """
    hsv_pixel = np.array([[[h, s, v]]], dtype=np.uint8)
    blue, green, red = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2BGR)[0][0]
    return [red / 255, green / 255, blue / 255]
20
+
21
+ # def remove_bg(
22
+ # path,
23
+ # BLUR = 21,
24
+ # CANNY_THRESH_1 = 10,
25
+ # CANNY_THRESH_2 = 200,
26
+ # MASK_DILATE_ITER = 10,
27
+ # MASK_ERODE_ITER = 10,
28
+ # MASK_COLOR = (0.0,0.0,1.0),
29
+ # ):
30
+ # img = cv2.imread(path)
31
+ # gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
32
+
33
+ # edges = cv2.Canny(gray, CANNY_THRESH_1, CANNY_THRESH_2)
34
+ # edges = cv2.dilate(edges, None)
35
+ # edges = cv2.erode(edges, None)
36
+
37
+ # contour_info = []
38
+ # contours, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
39
+ # for c in contours:
40
+ # contour_info.append((
41
+ # c,
42
+ # cv2.isContourConvex(c),
43
+ # cv2.contourArea(c),
44
+ # ))
45
+ # contour_info = sorted(contour_info, key=lambda c: c[2], reverse=True)
46
+ # max_contour = contour_info[0]
47
+
48
+ # mask = np.zeros(edges.shape)
49
+ # cv2.fillConvexPoly(mask, max_contour[0], (255))
50
+
51
+ # mask = cv2.dilate(mask, None, iterations=MASK_DILATE_ITER)
52
+ # mask = cv2.erode(mask, None, iterations=MASK_ERODE_ITER)
53
+ # mask = cv2.GaussianBlur(mask, (BLUR, BLUR), 0)
54
+ # mask_stack = np.dstack([mask]*3) # Create 3-channel alpha mask
55
+
56
+ # mask_stack = mask_stack.astype('float32') / 255.0 # Use float matrices,
57
+ # img = img.astype('float32') / 255.0 # for easy blending
58
+
59
+ # masked = (mask_stack * img) + ((1-mask_stack) * MASK_COLOR) # Blend
60
+ # masked = (masked * 255).astype('uint8') # Convert back to 8-bit
61
+
62
+ # c_blue, c_green, c_red = cv2.split(img)
63
+
64
+ # img_a = cv2.merge((c_red, c_green, c_blue, mask.astype('float32') / 255.0))
65
+ # index = np.where(img_a[:, :, 3] == 0)
66
+ # #img_a[index] = [1.0, 1.0, 1.0, 1.0]
67
+ # return img_a
favicon.jpeg ADDED
icon.jpeg DELETED
Binary file (254 kB)
 
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  streamlit==0.76.0
2
  Pillow
3
  opencv-python
4
- seaborn
 
 
 
1
  streamlit==0.76.0
2
  Pillow
3
  opencv-python
4
+ seaborn
5
+ http://download.pytorch.org/whl/cpu/torch-1.4.0%2Bcpu-cp36-cp36m-linux_x86_64.whl
6
+ torchvision==0.5.0
torchvision_funcs.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from PIL import Image
4
+
5
+ import torch
6
+ import torchvision
7
+ from torchvision import transforms
8
+
9
# Loaded models keyed by device string, so the pretrained network is built
# once per process instead of on every call.
_DEEPLABV3_MODELS = {}


def _get_deeplabv3_model(device):
    """Return an eval-mode DeepLabV3-ResNet101 on ``device``, loading it once."""
    key = str(device)
    if key not in _DEEPLABV3_MODELS:
        model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
        _DEEPLABV3_MODELS[key] = model.to(device).eval()
    return _DEEPLABV3_MODELS[key]


def _to_rgb_array(img):
    """Coerce a PIL image or array-like into an ``(H, W, 3)`` uint8 array."""
    if isinstance(img, Image.Image):
        # Palette, grayscale and RGBA images all become plain 3-channel RGB.
        img = img.convert('RGB')
    arr = np.array(img, dtype=np.uint8)
    if arr.ndim == 2:
        # Grayscale array: replicate the single plane into three channels.
        arr = np.dstack([arr, arr, arr])
    elif arr.shape[2] > 3:
        # Drop alpha (or any extra channels); Normalize below expects three.
        arr = arr[:, :, :3]
    return np.ascontiguousarray(arr)


def deeplabv3_remove_bg(img):
    """Black out the background of ``img`` using DeepLabV3 segmentation.

    Args:
        img: A PIL image or array-like image. Grayscale and RGBA inputs are
            converted to 3-channel RGB before segmentation.

    Returns:
        A tuple ``(img_masked, index_masked)`` where ``img_masked`` is a PIL
        image with every background pixel set to zero, and ``index_masked``
        is the ``np.where`` index tuple (row indices, column indices) of the
        background pixels.
    """
    img = _to_rgb_array(img)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = _get_deeplabv3_model(device)

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_batch = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)['out'][0]

    # argmax gives a per-pixel class index, not a 0/1 flag, so values such
    # as 3 are expected. Index 0 is taken to be the background class of the
    # pretrained weights (per the torchvision model docs); everything else
    # is collapsed into a single foreground label.
    labels = output.argmax(0).cpu().numpy()
    foreground = (labels > 0).astype(np.uint8)

    img_masked = Image.fromarray(img * foreground[:, :, np.newaxis])
    index_masked = np.where(foreground == 0)
    return img_masked, index_masked
40
+
41
+