sentencebird
commited on
Commit
·
e82d36a
1
Parent(s):
66115cd
add: torchvisionのsegmentationで背景マスキング
Browse files- app.py +28 -49
- cv_funcs.py +67 -0
- favicon.jpeg +0 -0
- icon.jpeg +0 -0
- requirements.txt +3 -1
- torchvision_funcs.py +41 -0
app.py
CHANGED
@@ -6,83 +6,62 @@ import random
|
|
6 |
import time
|
7 |
import seaborn as sns
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
dst.paste(im1, (0, 0))
|
12 |
-
dst.paste(im2, (im1.width, 0))
|
13 |
-
return dst
|
14 |
-
|
15 |
-
def get_concat_v(im1, im2):
|
16 |
-
dst = Image.new('RGB', (im1.width, im1.height + im2.height))
|
17 |
-
dst.paste(im1, (0, 0))
|
18 |
-
dst.paste(im2, (0, im1.height))
|
19 |
-
return dst
|
20 |
-
|
21 |
-
def hsv_to_rgb(h, s, v):
|
22 |
-
bgr = cv2.cvtColor(np.array([[[h, s, v]]], dtype=np.uint8), cv2.COLOR_HSV2BGR)[0][0]
|
23 |
-
return [bgr[2]/255, bgr[1]/255, bgr[0]/255]
|
24 |
|
25 |
@st.cache
|
26 |
def show_generated_image(image):
|
27 |
st.image(image)
|
28 |
-
|
29 |
@st.cache(suppress_st_warning=True)
|
30 |
-
def randomize_palette_colors(n_rows, n_cols,
|
31 |
random.seed(seed)
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
return colors1, colors2
|
36 |
|
37 |
@st.cache(suppress_st_warning=True)
|
38 |
-
def
|
39 |
-
|
40 |
-
colors1 = [[random.random() for j in range(3)] for i in range(n_rows*n_cols)]
|
41 |
-
colors2 = [[random.random() for j in range(3)] for i in range(n_rows*n_cols)]
|
42 |
-
return colors1, colors2
|
43 |
-
|
44 |
-
@st.cache(suppress_st_warning=True)
|
45 |
-
def randomize_hsv_colors(n_rows, n_cols, s=255, v=255, seed=0):
|
46 |
-
random.seed(seed)
|
47 |
-
colors1 = [hsv_to_rgb(random.random()*180, s, v) for i in range(n_rows*n_cols)]
|
48 |
-
colors2 = [hsv_to_rgb(random.random()*180, s, v) for i in range(n_rows*n_cols)]
|
49 |
-
return colors1, colors2
|
50 |
|
51 |
title = 'Andy Warhol like Image Generator'
|
52 |
-
st.set_page_config(page_title=title, layout='centered')
|
53 |
st.title(title)
|
54 |
uploaded_file = st.file_uploader('Choose an image file')
|
55 |
if uploaded_file is None: uploaded_file = './sample.jpg'
|
56 |
|
57 |
if uploaded_file is not None:
|
58 |
im = Image.open(uploaded_file)
|
59 |
-
im.thumbnail((1000, 1000),resample=Image.BICUBIC) # resize
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
63 |
thresh, _img = cv2.threshold(im_gray, 0, 255, cv2.THRESH_OTSU)
|
64 |
|
65 |
n_rows, n_cols = st.number_input('Rows', value=3), st.number_input('Columns', value=3)
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
colors1, colors2 = randomize_palette_colors(n_rows, n_cols)
|
70 |
-
thresh = st.slider('Threshold', value=thresh, min_value=0.0, max_value=255.0)
|
71 |
|
72 |
if st.button('Shuffle colors'):
|
73 |
-
|
74 |
-
|
75 |
if True or st.button('Generate'):
|
76 |
-
im_bool = im_gray > thresh
|
77 |
-
|
78 |
ims_generated = []
|
|
|
79 |
for row in range(n_rows):
|
80 |
for col in range(n_cols):
|
81 |
i_color = n_cols * row + col
|
82 |
-
|
83 |
ims_col = np.empty((*im_gray.shape, 3))
|
84 |
for i in range(3): # RGB
|
85 |
-
ims_col[:, :, i] = (im_gray
|
|
|
86 |
if col == 0:
|
87 |
im_col_concat = Image.fromarray(ims_col.astype(np.uint8))
|
88 |
else:
|
|
|
6 |
import time
|
7 |
import seaborn as sns
|
8 |
|
9 |
+
from cv_funcs import *
|
10 |
+
from torchvision_funcs import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
@st.cache
|
13 |
def show_generated_image(image):
|
14 |
st.image(image)
|
15 |
+
|
16 |
@st.cache(suppress_st_warning=True)
|
17 |
+
def randomize_palette_colors(n_rows, n_cols, palettes=['Set1', 'Set3', 'Spectral'], seed=time.time(), n_times=10):
|
18 |
random.seed(seed)
|
19 |
+
colors = [sns.color_palette(palette, n_rows*n_cols*n_times) for palette in palettes]
|
20 |
+
_ = [random.shuffle(color) for color in colors]
|
21 |
+
return colors
|
|
|
22 |
|
23 |
@st.cache(suppress_st_warning=True)
|
24 |
+
def remove_image_background(image):
|
25 |
+
return deeplabv3_remove_bg(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
title = 'Andy Warhol like Image Generator'
|
28 |
+
st.set_page_config(page_title=title, page_icon='favicon.jpeg', layout='centered')
|
29 |
st.title(title)
|
30 |
uploaded_file = st.file_uploader('Choose an image file')
|
31 |
if uploaded_file is None: uploaded_file = './sample.jpg'
|
32 |
|
33 |
if uploaded_file is not None:
|
34 |
im = Image.open(uploaded_file)
|
35 |
+
im.thumbnail((1000, 1000),resample=Image.BICUBIC) # resize
|
36 |
+
|
37 |
+
is_masked = st.checkbox('With background masking? (3 colors)')
|
38 |
+
if is_masked:
|
39 |
+
im_masked, index_masked = remove_image_background(im)
|
40 |
+
st.image(im_masked, caption='Masked image')
|
41 |
+
else: st.image(im, caption='Original')
|
42 |
+
|
43 |
+
im_gray = np.array(im.convert('L'))
|
44 |
thresh, _img = cv2.threshold(im_gray, 0, 255, cv2.THRESH_OTSU)
|
45 |
|
46 |
n_rows, n_cols = st.number_input('Rows', value=3), st.number_input('Columns', value=3)
|
47 |
+
|
48 |
+
thresh = st.slider('Threshold', value=thresh, min_value=0.0, max_value=255.0)
|
49 |
+
colors = randomize_palette_colors(n_rows, n_cols, seed=0)
|
|
|
|
|
50 |
|
51 |
if st.button('Shuffle colors'):
|
52 |
+
colors = randomize_palette_colors(n_rows, n_cols, seed=time.time())
|
53 |
+
|
54 |
if True or st.button('Generate'):
|
|
|
|
|
55 |
ims_generated = []
|
56 |
+
|
57 |
for row in range(n_rows):
|
58 |
for col in range(n_cols):
|
59 |
i_color = n_cols * row + col
|
60 |
+
rgbs = [np.array(color[i_color])*np.array([255, 255, 255]).tolist() for color in colors]
|
61 |
ims_col = np.empty((*im_gray.shape, 3))
|
62 |
for i in range(3): # RGB
|
63 |
+
ims_col[:, :, i] = (im_gray <= thresh) * rgbs[0][i] + (im_gray > thresh) * rgbs[1][i]
|
64 |
+
if is_masked: ims_col[:, :, i][index_masked] = rgbs[2][i]
|
65 |
if col == 0:
|
66 |
im_col_concat = Image.fromarray(ims_col.astype(np.uint8))
|
67 |
else:
|
cv_funcs.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
def get_concat_h(im1, im2):
|
6 |
+
dst = Image.new('RGB', (im1.width + im2.width, im1.height))
|
7 |
+
dst.paste(im1, (0, 0))
|
8 |
+
dst.paste(im2, (im1.width, 0))
|
9 |
+
return dst
|
10 |
+
|
11 |
+
def get_concat_v(im1, im2):
|
12 |
+
dst = Image.new('RGB', (im1.width, im1.height + im2.height))
|
13 |
+
dst.paste(im1, (0, 0))
|
14 |
+
dst.paste(im2, (0, im1.height))
|
15 |
+
return dst
|
16 |
+
|
17 |
+
def hsv_to_rgb(h, s, v):
|
18 |
+
bgr = cv2.cvtColor(np.array([[[h, s, v]]], dtype=np.uint8), cv2.COLOR_HSV2BGR)[0][0]
|
19 |
+
return [bgr[2]/255, bgr[1]/255, bgr[0]/255]
|
20 |
+
|
21 |
+
# def remove_bg(
|
22 |
+
# path,
|
23 |
+
# BLUR = 21,
|
24 |
+
# CANNY_THRESH_1 = 10,
|
25 |
+
# CANNY_THRESH_2 = 200,
|
26 |
+
# MASK_DILATE_ITER = 10,
|
27 |
+
# MASK_ERODE_ITER = 10,
|
28 |
+
# MASK_COLOR = (0.0,0.0,1.0),
|
29 |
+
# ):
|
30 |
+
# img = cv2.imread(path)
|
31 |
+
# gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
|
32 |
+
|
33 |
+
# edges = cv2.Canny(gray, CANNY_THRESH_1, CANNY_THRESH_2)
|
34 |
+
# edges = cv2.dilate(edges, None)
|
35 |
+
# edges = cv2.erode(edges, None)
|
36 |
+
|
37 |
+
# contour_info = []
|
38 |
+
# contours, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
|
39 |
+
# for c in contours:
|
40 |
+
# contour_info.append((
|
41 |
+
# c,
|
42 |
+
# cv2.isContourConvex(c),
|
43 |
+
# cv2.contourArea(c),
|
44 |
+
# ))
|
45 |
+
# contour_info = sorted(contour_info, key=lambda c: c[2], reverse=True)
|
46 |
+
# max_contour = contour_info[0]
|
47 |
+
|
48 |
+
# mask = np.zeros(edges.shape)
|
49 |
+
# cv2.fillConvexPoly(mask, max_contour[0], (255))
|
50 |
+
|
51 |
+
# mask = cv2.dilate(mask, None, iterations=MASK_DILATE_ITER)
|
52 |
+
# mask = cv2.erode(mask, None, iterations=MASK_ERODE_ITER)
|
53 |
+
# mask = cv2.GaussianBlur(mask, (BLUR, BLUR), 0)
|
54 |
+
# mask_stack = np.dstack([mask]*3) # Create 3-channel alpha mask
|
55 |
+
|
56 |
+
# mask_stack = mask_stack.astype('float32') / 255.0 # Use float matrices,
|
57 |
+
# img = img.astype('float32') / 255.0 # for easy blending
|
58 |
+
|
59 |
+
# masked = (mask_stack * img) + ((1-mask_stack) * MASK_COLOR) # Blend
|
60 |
+
# masked = (masked * 255).astype('uint8') # Convert back to 8-bit
|
61 |
+
|
62 |
+
# c_blue, c_green, c_red = cv2.split(img)
|
63 |
+
|
64 |
+
# img_a = cv2.merge((c_red, c_green, c_blue, mask.astype('float32') / 255.0))
|
65 |
+
# index = np.where(img_a[:, :, 3] == 0)
|
66 |
+
# #img_a[index] = [1.0, 1.0, 1.0, 1.0]
|
67 |
+
# return img_a
|
favicon.jpeg
ADDED
![]() |
icon.jpeg
DELETED
Binary file (254 kB)
|
|
requirements.txt
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
streamlit==0.76.0
|
2 |
Pillow
|
3 |
opencv-python
|
4 |
-
seaborn
|
|
|
|
|
|
1 |
streamlit==0.76.0
|
2 |
Pillow
|
3 |
opencv-python
|
4 |
+
seaborn
|
5 |
+
http://download.pytorch.org/whl/cpu/torch-1.4.0%2Bcpu-cp36-cp36m-linux_x86_64.whl
|
6 |
+
torchvision==0.5.0
|
torchvision_funcs.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
def deeplabv3_remove_bg(img):
|
10 |
+
img = np.array(img, dtype=np.uint8)
|
11 |
+
# img = cv2.imread(image_path)
|
12 |
+
# img = img[...,::-1] #BGR->RGB
|
13 |
+
h,w,_ = img.shape
|
14 |
+
# img = cv2.resize(img,(1000,1000))
|
15 |
+
|
16 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
17 |
+
|
18 |
+
model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
|
19 |
+
model = model.to(device)
|
20 |
+
model.eval();
|
21 |
+
|
22 |
+
preprocess = transforms.Compose([
|
23 |
+
transforms.ToTensor(),
|
24 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
25 |
+
])
|
26 |
+
input_tensor = preprocess(img)
|
27 |
+
input_batch = input_tensor.unsqueeze(0).to(device)
|
28 |
+
|
29 |
+
with torch.no_grad():
|
30 |
+
output = model(input_batch)['out'][0]
|
31 |
+
output = output.argmax(0)
|
32 |
+
mask = output.byte().cpu().numpy()
|
33 |
+
# mask = cv2.resize(mask,(w,h))
|
34 |
+
# img = cv2.resize(img,(w,h))
|
35 |
+
mask[mask>0] = 1.0 # NOTE: なぜか3が入っていたので
|
36 |
+
mask = np.dstack([mask, mask, mask])
|
37 |
+
img_masked = Image.fromarray(cv2.multiply(img, mask))
|
38 |
+
index_masked = np.where(np.array(mask)[:,:,2]==0)
|
39 |
+
return img_masked, index_masked
|
40 |
+
|
41 |
+
|