# Hugging Face Spaces status: Running
| import numpy as np | |
| import cv2 | |
| import onnxruntime | |
| import gradio as gr | |
# HTML rendered below the interface (passed as `article` to gr.Interface):
# a short support message plus a "Buy Me a Coffee" image that links to an
# external purchase page.
article_text = """
<div style="text-align: center;">
<p>Enjoying the tool? Buy me a coffee and get exclusive prompt guides!</p>
<p><i>Instantly unlock helpful tips for creating better prompts!</i></p>
<div style="display: flex; justify-content: center;">
<a href="https://piczify.lemonsqueezy.com/buy/0f5206fa-68e8-42f6-9ca8-4f80c587c83e">
<img src="https://www.buymeacoffee.com/assets/img/custom_images/yellow_img.png"
alt="Buy Me a Coffee"
style="height: 40px; width: auto; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); border-radius: 10px;">
</a>
</div>
</div>
"""
def pre_process(img: np.ndarray) -> np.ndarray:
    """Turn an H x W x C image into a 1 x C x H x W float32 batch for the model.

    Only the first three channels of ``img`` are used; any extra channel
    (e.g. alpha) is dropped.

    Args:
        img: Image array indexed as (height, width, channels).

    Returns:
        A C-contiguous float32 array of shape (1, 3, H, W) (or (1, C, H, W)
        when the input has fewer than three channels).
    """
    # H, W, C -> C, H, W (keeping at most the first three channels)
    chw = np.transpose(img[:, :, 0:3], (2, 0, 1))
    # C, H, W -> 1, C, H, W. Transposing yields a strided view; materialise a
    # dense C-ordered float32 buffer so the runtime gets a plain NCHW tensor
    # instead of relying on an implicit copy of a non-contiguous array.
    return np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)
def post_process(img: np.ndarray) -> np.ndarray:
    """Turn the model's 1 x C x H x W output into an H x W x C uint8 image.

    The channel axis is reversed (so a BGR-ordered model output becomes RGB,
    assuming the model preserves its input channel order).

    Args:
        img: Model output with a leading batch axis of size 1.

    Returns:
        A uint8 array of shape (H, W, C) with values clipped to [0, 255].
    """
    # 1, C, H, W -> C, H, W. Squeeze only the batch axis: a bare np.squeeze
    # would also drop a height, width or channel axis of size 1 and break
    # the transpose below.
    img = np.squeeze(img, axis=0)
    # C, H, W -> H, W, C, then reverse the channel order.
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1]
    # Round and clip before casting: float values outside [0, 255] would
    # otherwise wrap around when converted to uint8.
    return np.clip(np.rint(img), 0, 255).astype(np.uint8)
# Loaded ONNX Runtime sessions keyed by model path, so each model file is
# read and initialised once instead of on every request.
_SESSION_CACHE = {}


def _get_session(model_path: str) -> "onnxruntime.InferenceSession":
    """Return the cached InferenceSession for *model_path*, creating it on first use."""
    session = _SESSION_CACHE.get(model_path)
    if session is None:
        options = onnxruntime.SessionOptions()
        # Run each operator, and the graph as a whole, on a single thread.
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        session = onnxruntime.InferenceSession(model_path, options)
        _SESSION_CACHE[model_path] = session
    return session


def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
    """Run the model stored at *model_path* on *img_array*.

    The array is fed to the model's first input; the model's first output
    is returned. Sessions are reused across calls (see ``_get_session``).

    Args:
        model_path: Path to the ONNX Runtime model file.
        img_array: Input tensor, as produced by ``pre_process``.

    Returns:
        The first output tensor of the model.
    """
    ort_session = _get_session(model_path)
    input_name = ort_session.get_inputs()[0].name
    return ort_session.run(None, {input_name: img_array})[0]
def convert_pil_to_cv2(image):
    """Convert a PIL image (or array) to an OpenCV-style channel order.

    - 2-D (grayscale) input is returned unchanged.
    - 4-channel RGBA input becomes BGRA (alpha stays the last channel).
    - Any other multi-channel input has its channel axis reversed
      (RGB -> BGR).

    Args:
        image: A PIL image or anything ``np.array`` accepts.

    Returns:
        A new, C-contiguous numpy array.
    """
    arr = np.array(image)
    if arr.ndim == 2:
        # Grayscale: there is no channel axis to reorder.
        return arr
    if arr.shape[2] == 4:
        # RGBA -> BGRA. Reversing all four channels would give ABGR and
        # move alpha out of the last position callers expect.
        return arr[:, :, [2, 1, 0, 3]].copy()
    # RGB to BGR
    return arr[:, :, ::-1].copy()
def upscale(image, model):
    """Upscale *image* with the model file ``models/<model>.ort``.

    Grayscale input is expanded to three channels first. For 4-channel
    input, the alpha channel is upscaled separately by the same model and
    re-attached to the upscaled colour image.

    Args:
        image: The PIL image from the Gradio input.
        model: Model name chosen in the UI; used to build the model path.

    Returns:
        A uint8 array: 3 channels for colour/grayscale input, 4 channels
        (alpha last) for input with an alpha channel.

    Raises:
        ValueError: If no image is given or its channel count is unsupported.
    """
    if image is None:
        raise ValueError("No input image was provided.")
    model_path = f"models/{model}.ort"
    img = convert_pil_to_cv2(image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    channels = img.shape[2]

    if channels == 3:
        # post_process reverses the channel order of the model output.
        return post_process(inference(model_path, pre_process(img)))

    if channels == 4:
        # The model takes three channels, so replicate the alpha plane into
        # three channels, upscale it, then collapse it back to one channel.
        alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
        alpha_output = post_process(inference(model_path, pre_process(alpha)))
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)

        image_output = post_process(inference(model_path, pre_process(img[:, :, 0:3])))
        # Append a fourth channel (filled by OpenCV) and overwrite it with
        # the upscaled alpha.
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)
        image_output[:, :, 3] = alpha_output
        return image_output

    raise ValueError(f"Unsupported number of image channels: {channels}")
# Extra CSS applied to the app: fixes the height of the image panels.
css = ".output-image, .input-image, .image-preview {height: 480px !important} "

# Model names selectable in the UI; each maps to models/<name>.ort.
model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]

# Input components, in the order their values are passed to `upscale`.
_image_input = gr.Image(type="pil", label="Input Image")
_model_selector = gr.Radio(
    model_choices,
    type="value",
    value="modelx4",
    label="Choose Upscaler",
)

# Header HTML shown above the interface.
_description_html = """
<div style="text-align: center;">
<h1>Image Upscaler PRO ⚡</h1>
<a href="https://arxiv.org/abs/2105.09750">
<img src="https://img.shields.io/badge/arXiv-2105.09750-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
</a>
<p>Anchor-based Plain Net for Mobile Image Super-Resolution</p>
</div>
"""

demo = gr.Interface(
    fn=upscale,
    inputs=[_image_input, _model_selector],
    outputs="image",
    description=_description_html,
    article=article_text,
    allow_flagging="never",
    css=css,
)
demo.launch()