# sentencebird's picture
# Update app.py
# fc866a2
import streamlit as st
from PIL import Image, ImageOps
import cv2
import numpy as np
import random
import time
import seaborn as sns
from cv_funcs import *
from torchvision_funcs import *
from backgroundremover.utilities import download_download_files_from_github
import torch
import os
from hsh.library.hash import Hasher
from torchvision import transforms
def load_model(model_name: str = "u2net"):
    """Build a U^2-Net variant and load its pretrained weights.

    The weights path is read from the model's environment variable
    (``U2NET_PATH`` for every entry in the table below) and otherwise
    defaults to ``~/.u2net/<model_name>.pth``. When that file is missing or
    its MD5 digest differs from the expected one, the download helper is
    invoked before loading.

    Args:
        model_name: One of ``"u2net"``, ``"u2net_human_seg"`` or ``"u2netp"``.

    Returns:
        The network in eval mode; it is moved to CUDA when CUDA is available.

    Raises:
        KeyError: If ``model_name`` is not in the model table.
        FileNotFoundError: If the weights file cannot be opened.
    """
    import errno  # function-scope import: only the error path below needs it

    # name -> (network class, expected MD5 of the weights file,
    #          third field unused here (presumably a download id -- TODO confirm),
    #          environment variable that may override the weights path)
    models = {
        'u2netp': (u2net.U2NETP,
                   'e4f636406ca4e2af789941e7f139ee2e',
                   '1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy',
                   'U2NET_PATH'),
        'u2net': (u2net.U2NET,
                  '09fb4e49b7f785c9f855baf94916840a',
                  '1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ',
                  'U2NET_PATH'),
        'u2net_human_seg': (u2net.U2NET,
                            '347c3d51b01528e5c6c071e3cff1cb55',
                            '1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P',
                            'U2NET_PATH'),
    }
    try:
        net_cls, expected_md5, _unused_id, env_var = models[model_name]
    except KeyError:
        raise KeyError(
            "Choose between u2net, u2net_human_seg or u2netp"
        ) from None

    net = net_cls(3, 1)
    path = os.environ.get(
        env_var,
        os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
    )
    if not os.path.exists(path) or Hasher().md5(path) != expected_md5:
        # Call the helper under the name this file actually imports.
        # NOTE(review): confirm this is the helper's real name in
        # backgroundremover.utilities.
        download_download_files_from_github(path, model_name)

    try:
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(path))
            net.to(torch.device("cuda"))
        else:
            net.load_state_dict(torch.load(path, map_location="cpu"))
    except FileNotFoundError as err:
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
        ) from err

    net.eval()
    return net
def norm_pred(d):
    """Min-max normalise a tensor so its values span ``[0, 1]``.

    Args:
        d: A non-empty tensor.

    Returns:
        ``(d - min) / (max - min)``. When every element of ``d`` is equal the
        range is zero; the result is then all zeros rather than NaN.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    # Substitute 1 for a zero range so a constant input maps to 0 instead of 0/0.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    return (d - mi) / safe_span
def preprocess(image):
    """Wrap an image array in the sample dict the data-loader transforms take.

    An all-zero label is attached alongside the image. For 2-D and 3-D
    images the label has shape ``(H, W, 1)``; a 2-D image also gains a
    trailing channel axis. The sample is then passed through
    ``RescaleT(320)`` followed by ``ToTensorLab(flag=0)``.

    Args:
        image: Array-like exposing ``.shape`` and NumPy-style indexing.

    Returns:
        Whatever the composed transform returns for the sample dict.
    """
    rank = len(image.shape)

    # Placeholder label covering the first two axes of the image.
    label = np.zeros(image.shape[0:2])
    if rank in (2, 3):
        label = label[:, :, np.newaxis]
    if rank == 2:
        image = image[:, :, np.newaxis]

    pipeline = transforms.Compose(
        [data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
    )
    return pipeline({"imidx": np.array([0]), "image": image, "label": label})
def predict(net, item):
    """Run ``net`` on one image and return its first output as a PIL image.

    The image is preprocessed, given a batch axis and fed to the network
    without gradient tracking. Channel 0 of the first of the network's seven
    outputs is min-max normalised, scaled by 255 and converted to RGB.

    Args:
        net: Callable network returning seven outputs.
        item: Image array accepted by ``preprocess``.

    Returns:
        An RGB ``PIL.Image`` built from the normalised first output.
    """
    sample = preprocess(item)

    # Plain tensor ops replace the legacy typed constructors
    # (torch.FloatTensor / torch.cuda.FloatTensor).
    batch = sample["image"].unsqueeze(0)
    if torch.cuda.is_available():
        batch = batch.cuda()
    batch = batch.float()

    with torch.no_grad():
        d1, _d2, _d3, _d4, _d5, _d6, _d7 = net(batch)

    mask = norm_pred(d1[:, 0, :, :]).squeeze()
    mask_np = mask.cpu().detach().numpy()
    return Image.fromarray(mask_np * 255).convert("RGB")
def remove_bg(img):
    """Black out the pixels of ``img`` that the u2net mask marks as zero.

    Args:
        img: A ``PIL.Image``; it is converted to RGB first so that images in
            other modes (e.g. RGBA or L) line up with the mask.

    Returns:
        A tuple ``(img_masked, index_masked)``:

        * ``img_masked`` -- RGB ``PIL.Image`` with masked-out pixels set to 0.
        * ``index_masked`` -- ``(rows, cols)`` index arrays of the pixels
          whose mask value is 0, usable to index a 2-D ``(H, W)`` array.
    """
    img = img.convert("RGB")
    img_arry = np.array(img)

    model = load_model(model_name="u2net")
    mask = predict(model, img_arry).resize(img.size)

    # predict() converts a single-channel mask to RGB, so the three channels
    # are identical. One channel gives (row, col) indices; np.where on the
    # full H x W x 3 array would return a 3-tuple that cannot index a plane.
    mask_2d = np.array(mask)[:, :, 0]

    keep = (mask_2d > 0).astype(img_arry.dtype)
    img_masked = Image.fromarray(img_arry * keep[:, :, np.newaxis])
    index_masked = np.where(mask_2d == 0)
    return img_masked, index_masked
def show_generated_image(image):
    """Draw ``image`` in the Streamlit app.

    This function is intentionally not wrapped in ``st.cache``: it returns
    nothing and exists only for its drawing side effect, which a memoised
    call would skip whenever the cache is hit.
    """
    st.image(image)
@st.cache(suppress_st_warning=True)
def randomize_palette_colors(n_rows, n_cols, palettes=('Set1', 'Set3', 'Spectral'), seed=None, n_times=10):
    """Return one shuffled colour list per seaborn palette.

    Args:
        n_rows: Number of tile rows.
        n_cols: Number of tile columns.
        palettes: Names of the seaborn palettes to sample from.
        seed: Seed for the shuffle. ``None`` means "use the current time",
            resolved at call time rather than once at definition time.
        n_times: Multiplier on ``n_rows * n_cols`` giving the number of
            colours drawn from each palette.

    Returns:
        A list with one entry per palette; each entry holds
        ``n_rows * n_cols * n_times`` colours in shuffled order.
    """
    if seed is None:
        seed = time.time()
    # A private generator yields the same shuffle as seeding the module-level
    # one, without disturbing global random state.
    rng = random.Random(seed)

    n_colors = n_rows * n_cols * n_times
    colors = [sns.color_palette(palette, n_colors) for palette in palettes]
    for palette_colors in colors:
        rng.shuffle(palette_colors)
    return colors
@st.cache(suppress_st_warning=True)
def remove_image_background(image):
    """Cached wrapper that runs ``remove_bg`` on ``image``.

    Args:
        image: The image to process; passed straight to ``remove_bg``.

    Returns:
        Whatever ``remove_bg`` returns for ``image``.
    """
    #return deeplabv3_remove_bg(image)
    # Pass the parameter itself; the previous code referenced an undefined
    # name (`img`) and raised NameError.
    return remove_bg(image)
# Streamlit page: load an image, threshold its grayscale version, and tile
# two-colour (three-colour with masking) renderings in an n_rows x n_cols grid.
title = 'Andy Warhol like Image Generator'
st.set_page_config(page_title=title, page_icon='favicon.jpeg', layout='centered')
st.title(title)

uploaded_file = st.file_uploader('Choose an image file')
# Fall back to the bundled sample while nothing has been uploaded.
if uploaded_file is None:
    uploaded_file = './sample.jpg'

im = Image.open(uploaded_file)
im.thumbnail((1000, 1000), resample=Image.BICUBIC)  # resize in place

is_masked = st.checkbox('With background masking? (3 colors)')
if is_masked:
    im_masked, index_masked = remove_image_background(im)
    st.image(im_masked, caption='Masked image')
else:
    st.image(im, caption='Original')

im_gray = np.array(im.convert('L'))
# The Otsu threshold seeds the slider's initial value.
thresh, _img = cv2.threshold(im_gray, 0, 255, cv2.THRESH_OTSU)

# min_value=1 keeps the grid non-empty; with zero rows or columns nothing
# would be generated and the final st.image call would hit an undefined name.
n_rows = st.number_input('Rows', value=3, min_value=1)
n_cols = st.number_input('Columns', value=3, min_value=1)
thresh = st.slider('Threshold', value=thresh, min_value=0.0, max_value=255.0)

colors = randomize_palette_colors(n_rows, n_cols, seed=0)
if st.button('Shuffle colors'):
    colors = randomize_palette_colors(n_rows, n_cols, seed=time.time())

# The image is always generated; the old `True or st.button('Generate')`
# test never evaluated (or rendered) the button.
# The two threshold masks do not depend on the tile, so compute them once.
at_or_below = im_gray <= thresh
above = im_gray > thresh

for row in range(n_rows):
    for col in range(n_cols):
        i_color = n_cols * row + col
        # One 0-255 RGB triple per palette for this tile.
        rgbs = [np.array(color[i_color]) * 255 for color in colors]
        ims_col = np.empty((*im_gray.shape, 3))
        for i in range(3):  # RGB
            ims_col[:, :, i] = at_or_below * rgbs[0][i] + above * rgbs[1][i]
            if is_masked:
                # Masked pixels take the third palette's colour.
                ims_col[:, :, i][index_masked] = rgbs[2][i]
        tile = Image.fromarray(ims_col.astype(np.uint8))
        if col == 0:
            im_col_concat = tile
        else:
            im_col_concat = get_concat_h(im_col_concat, tile)
    if row == 0:
        im_generated = im_col_concat
    else:
        im_generated = get_concat_v(im_generated, im_col_concat)

st.image(im_generated)