import spaces
import subprocess

# def install_mmcv():
#     try:
#         subprocess.run([
#             "pip", "install", "mmcv-full==1.7.2",
#             "-f", "https://download.openmmlab.com/mmcv/dist/cu121/torch2.1.0/"
#         ], check=True)
#     except subprocess.CalledProcessError as e:
#         print("Failed to install mmcv-full:", e)
# install_mmcv()

import mmcv
import gradio as gr
import numpy as np
import torch
import os
import cv2
from PIL import Image
import torchvision.transforms as transforms

from net.dornet import Net
from net.dornet_ddp import Net_ddp

print("=" * 50)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"torch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print(f"mmcv version: {mmcv.__version__}")
print("=" * 50)

# Init device and model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"  # hard-coded override: keep tensors on CPU regardless of CUDA availability
net = Net(tiny_model=False).to(device)

model_ckpt_map = {
    "RGB-D-D": "./checkpoints/RGBDD.pth",
    "TOFDSR": "./checkpoints/TOFDSR.pth"
}


# Load model weights for the selected dataset variant
@spaces.GPU
def load_model(model_type: str):
    global net
    ckpt_path = model_ckpt_map[model_type]
    print(f"Loading weights from: {ckpt_path}")
    if model_type == "RGB-D-D":
        net = Net(tiny_model=False).to(device)
    elif model_type == "TOFDSR":
        net = Net_ddp(tiny_model=False).srn.to(device)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()


load_model("RGB-D-D")


# Data preprocessing: normalize inputs and convert to batched tensors
@spaces.GPU
def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
    image = np.array(rgb_image.convert("RGB")).astype(np.float32)
    h, w, _ = image.shape
    lr = np.array(lr_depth.resize((w, h), Image.BICUBIC)).astype(np.float32)

    # Normalize depth to [0, 1] using a fixed [0, 5000] output range
    max_out, min_out = 5000.0, 0.0
    lr = (lr - min_out) / (max_out - min_out)

    # Normalize RGB by its own min/max
    maxx, minn = np.max(image), np.min(image)
    image = (image - minn) / (maxx - minn)

    # To tensor
    data_transform = transforms.Compose([transforms.ToTensor()])
    image = data_transform(image).float()
    lr = data_transform(np.expand_dims(lr, 2)).float()

    # Add batch dimension
    lr = lr.unsqueeze(0).to(device)
    image = image.unsqueeze(0).to(device)

    return image, lr, min_out, max_out


# Model inference
@spaces.GPU
@torch.no_grad()
def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
    load_model(model_type)  # reload weights for the selected model
    image, lr, min_out, max_out = preprocess_inputs(rgb_image, lr_depth)

    if model_type == "RGB-D-D":
        out = net(x_query=lr, rgb=image)
    elif model_type == "TOFDSR":
        out, _ = net(x_query=lr, rgb=image)

    # De-normalize back to the original depth range
    pred = out[0, 0] * (max_out - min_out) + min_out
    pred = pred.cpu().numpy().astype(np.uint16)

    # Raw 16-bit depth map
    pred_gray = Image.fromarray(pred)

    # Heatmap visualization (min-max normalized, PLASMA colormap)
    pred_norm = (pred - np.min(pred)) / (np.max(pred) - np.min(pred)) * 255
    pred_vis = pred_norm.astype(np.uint8)
    pred_heat = cv2.applyColorMap(pred_vis, cv2.COLORMAP_PLASMA)
    pred_heat = cv2.cvtColor(pred_heat, cv2.COLOR_BGR2RGB)

    # return pred_gray, Image.fromarray(pred_heat)
    return Image.fromarray(pred_heat)


Intro = """
## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution

[📄 Paper](https://arxiv.org/pdf/2410.11666) • [💻 Code](https://github.com/yanzq95/DORNet) • [📦 Model](https://huggingface.co/wzxwyx/DORNet/tree/main)
"""

with gr.Blocks(css="""
.output-image {
    display: flex;
    justify-content: center;
    align-items: center;
}
.output-image img {
    margin: auto;
    display: block;
}
""") as demo:
    gr.Markdown(Intro)

    with gr.Row():
        with gr.Column():
            rgb_input = gr.Image(label="RGB Image", type="pil")
            lr_input = gr.Image(label="Low-res Depth", type="pil", image_mode="I")
        with gr.Column():
            # output1 = gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"])
            output2 = gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])

    model_selector = gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
    run_button = gr.Button("Run Inference")

    gr.Examples(
        examples=[
            ["examples/RGB-D-D/20200518160957_RGB.jpg",
             "examples/RGB-D-D/20200518160957_LR_fill_depth.png",
             "RGB-D-D"],
            ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png",
             "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png",
             "TOFDSR"],
        ],
        inputs=[rgb_input, lr_input, model_selector],
        # outputs=[output1, output2],
        outputs=[output2],
        label="Try Examples ↓"
    )

    # run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output1, output2])
    run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output2])

demo.launch()
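
# --- Usage sketch (not part of the original app) ---
# A minimal way to exercise `infer` outside the Gradio UI, assuming the example
# files bundled with this repo exist at the paths used in gr.Examples above.
# Run it in place of demo.launch(), since launch() blocks.
#
#   from PIL import Image
#   rgb = Image.open("examples/RGB-D-D/20200518160957_RGB.jpg")
#   lr = Image.open("examples/RGB-D-D/20200518160957_LR_fill_depth.png")
#   infer(rgb, lr, "RGB-D-D").save("pred_heat.png")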