"""Streamlit app that renders an Andy Warhol style colour grid from an image.

The page thresholds a grayscale copy of the uploaded (or sample) image into
two colours per tile and, optionally, paints the pixels rejected by a U2-Net
saliency mask with a third colour.
"""
import errno
import os
import random
import sys
import time

import cv2
import numpy as np
import seaborn as sns
import streamlit as st
from PIL import Image, ImageOps

# NOTE(review): `u2net`, `data_loader`, `get_concat_h` and `get_concat_v` are
# used below but never imported explicitly; presumably they are exported by
# one of these wildcard imports -- confirm.
from cv_funcs import *
from torchvision_funcs import *

# These stay after the wildcard imports (as in the original file) so the
# explicit names win over anything the wildcards might export.
# NOTE(review): the imported helper name looks unusual; verify it against the
# installed `backgroundremover` package.
from backgroundremover.utilities import download_download_files_from_github
import torch
from hsh.library.hash import Hasher
from torchvision import transforms


def load_model(model_name: str = "u2net"):
    """Build the requested U2-Net variant and load its weights.

    The weight file is read from the environment variable listed for the
    model, falling back to ``~/.u2net/<model_name>.pth``.  When the file is
    missing or its MD5 does not match the expected checksum, the download
    helper is invoked before loading.

    Args:
        model_name: One of ``"u2net"``, ``"u2netp"`` or ``"u2net_human_seg"``.

    Returns:
        The network in eval mode, moved to CUDA when it is available.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
        FileNotFoundError: If the weight file cannot be opened.
    """
    hasher = Hasher()

    # name -> (network class, expected MD5, file id (unused here), path env var)
    supported = {
        'u2netp': (u2net.U2NETP,
                   'e4f636406ca4e2af789941e7f139ee2e',
                   '1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy',
                   'U2NET_PATH'),
        'u2net': (u2net.U2NET,
                  '09fb4e49b7f785c9f855baf94916840a',
                  '1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ',
                  'U2NET_PATH'),
        'u2net_human_seg': (u2net.U2NET,
                            '347c3d51b01528e5c6c071e3cff1cb55',
                            '1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P',
                            'U2NET_PATH'),
    }

    if model_name not in supported:
        message = "Choose between u2net, u2net_human_seg or u2netp"
        print(message, file=sys.stderr)
        raise ValueError(f"Unknown model name {model_name!r}. {message}")

    net_cls, expected_md5, _file_id, path_env_var = supported[model_name]
    net = net_cls(3, 1)
    path = os.environ.get(
        path_env_var,
        os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
    )

    if not os.path.exists(path) or hasher.md5(path) != expected_md5:
        # The original called `download_downloadfiles_from_github`, a name
        # that is not defined in this file; use the helper actually imported.
        download_download_files_from_github(path, model_name)

    try:
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(path))
            net.to(torch.device("cuda"))
        else:
            net.load_state_dict(torch.load(path, map_location="cpu"))
    except FileNotFoundError as err:
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
        ) from err

    net.eval()
    return net


def norm_pred(d):
    """Min-max rescale tensor ``d`` so its values span 0 to 1.

    NOTE(review): when every element of ``d`` is equal the denominator is
    zero, so the result is not finite.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    return (d - mi) / (ma - mi)


def preprocess(image):
    """Turn an image array into the sample dict produced by the transforms.

    An all-zero label with the same height and width as ``image`` is built,
    both arrays are given a trailing channel axis where needed, and the pair
    is passed through ``RescaleT(320)`` and ``ToTensorLab(flag=0)``.

    Args:
        image: NumPy array; 2-D and 3-D inputs get explicit handling.

    Returns:
        Whatever the composed transform returns for the sample
        (``predict`` reads its ``"image"`` entry).
    """
    label_3 = np.zeros(image.shape)
    label = np.zeros(label_3.shape[0:2])

    if 3 == len(label_3.shape):
        label = label_3[:, :, 0]
    elif 2 == len(label_3.shape):
        label = label_3

    if 3 == len(image.shape) and 2 == len(label.shape):
        label = label[:, :, np.newaxis]
    elif 2 == len(image.shape) and 2 == len(label.shape):
        image = image[:, :, np.newaxis]
        label = label[:, :, np.newaxis]

    transform = transforms.Compose(
        [data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
    )
    sample = transform({"imidx": np.array([0]), "image": image, "label": label})
    return sample


def predict(net, item):
    """Run ``net`` on one image array and return its mask as an RGB image.

    Args:
        net: Model returned by ``load_model``.
        item: Image array accepted by ``preprocess``.

    Returns:
        A PIL RGB image built from the first network output, normalised with
        ``norm_pred`` and scaled by 255.
    """
    sample = preprocess(item)

    with torch.no_grad():
        if torch.cuda.is_available():
            inputs_test = torch.cuda.FloatTensor(
                sample["image"].unsqueeze(0).cuda().float()
            )
        else:
            inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())

        # Only the first of the network's outputs is used for the mask.
        d1, *_rest = net(inputs_test)

        normalised = norm_pred(d1[:, 0, :, :]).squeeze()
        mask_np = normalised.cpu().detach().numpy()
        return Image.fromarray(mask_np * 255).convert("RGB")


def remove_bg(img):
    """Black out the pixels of ``img`` that the U2-Net mask marks as zero.

    Args:
        img: PIL image; it is converted to RGB so its array always has the
            same three channels as the mask it is multiplied with.

    Returns:
        ``(img_masked, index_masked)`` where ``img_masked`` is the masked PIL
        image and ``index_masked`` is a ``(rows, cols)`` tuple of index
        arrays for the pixels whose mask value is zero.
    """
    rgb_img = img.convert("RGB")
    img_arry = np.array(rgb_img)

    model = load_model(model_name="u2net")
    mask = predict(model, img_arry).resize(rgb_img.size)
    mask_arry = np.array(mask)

    # `predict` converts a single-band image to RGB, so one channel is enough
    # to locate the zero pixels.  Indexing a single channel yields 2-D
    # (rows, cols) indices that can address any per-pixel array.
    index_masked = np.where(mask_arry[:, :, 0] == 0)

    mask_arry[mask_arry > 0] = 1
    img_masked = Image.fromarray(cv2.multiply(img_arry, mask_arry))
    return img_masked, index_masked


@st.cache
def show_generated_image(image):
    """Display ``image`` on the page."""
    st.image(image)


@st.cache(suppress_st_warning=True)
def randomize_palette_colors(n_rows, n_cols,
                             palettes=('Set1', 'Set3', 'Spectral'),
                             seed=None, n_times=10):
    """Return one shuffled seaborn colour list per palette.

    Args:
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        palettes: Seaborn palette names; one colour list is built for each.
        seed: Seed for the shuffle.  ``None`` uses the current time, taken
            when the function is called rather than when it was defined.
        n_times: Each palette is sampled with ``n_rows * n_cols * n_times``
            colours.

    Returns:
        A list with one shuffled colour sequence per entry of ``palettes``.
    """
    if seed is None:
        seed = time.time()
    # A private generator gives the same shuffle as seeding the module-level
    # one, without disturbing global random state.
    rng = random.Random(seed)

    n_colors = n_rows * n_cols * n_times
    colors = [sns.color_palette(palette, n_colors) for palette in palettes]
    for color in colors:
        rng.shuffle(color)
    return colors


@st.cache(suppress_st_warning=True)
def remove_image_background(image):
    """Cached wrapper around ``remove_bg``; returns its ``(image, indices)``."""
    #return deeplabv3_remove_bg(image)
    # The original passed the undefined name `img` here.
    return remove_bg(image)


title = 'Andy Warhol like Image Generator'
st.set_page_config(page_title=title, page_icon='favicon.jpeg', layout='centered')
st.title(title)

uploaded_file = st.file_uploader('Choose an image file')
if uploaded_file is None:
    # Nothing uploaded yet: fall back to the sample image path.
    uploaded_file = './sample.jpg'

im = Image.open(uploaded_file)
im.thumbnail((1000, 1000), resample=Image.BICUBIC)  # resize

is_masked = st.checkbox('With background masking? (3 colors)')
if is_masked:
    im_masked, index_masked = remove_image_background(im)
    st.image(im_masked, caption='Masked image')
else:
    st.image(im, caption='Original')

im_gray = np.array(im.convert('L'))
thresh, _img = cv2.threshold(im_gray, 0, 255, cv2.THRESH_OTSU)

# At least one row and one column, otherwise no grid image would be produced.
n_rows = st.number_input('Rows', value=3, min_value=1)
n_cols = st.number_input('Columns', value=3, min_value=1)
thresh = st.slider('Threshold', value=thresh, min_value=0.0, max_value=255.0)

colors = randomize_palette_colors(n_rows, n_cols, seed=0)
if st.button('Shuffle colors'):
    colors = randomize_palette_colors(n_rows, n_cols, seed=time.time())

# The grid is regenerated on every run (the original guarded this with an
# always-true `if True or st.button('Generate')`).
at_or_below = (im_gray <= thresh)[:, :, np.newaxis]
for row in range(n_rows):
    for col in range(n_cols):
        i_color = n_cols * row + col
        # One 0-255 RGB triple per palette for this tile.
        rgbs = [np.array(color[i_color]) * 255 for color in colors]

        # First palette at or below the threshold, second palette above it.
        ims_col = np.where(at_or_below, rgbs[0], rgbs[1])
        if is_masked:
            # Third palette for pixels whose mask value is zero.
            ims_col[index_masked] = rgbs[2]

        im_tile = Image.fromarray(ims_col.astype(np.uint8))
        if col == 0:
            im_col_concat = im_tile
        else:
            im_col_concat = get_concat_h(im_col_concat, im_tile)

    if row == 0:
        im_generated = im_col_concat
    else:
        im_generated = get_concat_v(im_generated, im_col_concat)

st.image(im_generated)