import spaces
import subprocess

# def install_mmcv():
#     try:
#         subprocess.run([
#             "pip", "install", "mmcv-full==1.7.2",
#             "-f", "https://download.openmmlab.com/mmcv/dist/cu121/torch2.1.0/"
#         ], check=True)
#     except subprocess.CalledProcessError as e:
#         print("Failed to install mmcv-full:", e)
# install_mmcv()

import mmcv
import gradio as gr
import numpy as np
import torch
import os
import cv2
from PIL import Image
import torchvision.transforms as transforms

from net.dornet import Net
from net.dornet_ddp import Net_ddp
print("=" * 50) | |
print(f"CUDA available: {torch.cuda.is_available()}") | |
print(f"torch version: {torch.__version__}") | |
print(f"CUDA version: {torch.version.cuda}") | |
print(f"mmcv version: {mmcv.__version__}") | |
print("=" * 50) | |
# init
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"  # overrides the detection above: inference always runs on CPU here
net = Net(tiny_model=False).to(device)
model_ckpt_map = {
    "RGB-D-D": "./checkpoints/RGBDD.pth",
    "TOFDSR": "./checkpoints/TOFDSR.pth",
}
# load model
def load_model(model_type: str):
    """Rebuild the network for the selected dataset and load its checkpoint."""
    global net
    ckpt_path = model_ckpt_map[model_type]
    print(f"Loading weights from: {ckpt_path}")
    if model_type == "RGB-D-D":
        net = Net(tiny_model=False).to(device)
    elif model_type == "TOFDSR":
        # The TOFDSR weights come from DDP training; the wrapper keeps the
        # actual super-resolution network in its .srn attribute.
        net = Net_ddp(tiny_model=False).srn.to(device)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

load_model("RGB-D-D")
# data process
def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
    """Resize the LR depth to the RGB resolution, normalize both, and batch them."""
    image = np.array(rgb_image.convert("RGB")).astype(np.float32)
    h, w, _ = image.shape
    lr = np.array(lr_depth.resize((w, h), Image.BICUBIC)).astype(np.float32)
    # Normalize depth to [0, 1] with a fixed 0-5000 range
    max_out, min_out = 5000.0, 0.0
    lr = (lr - min_out) / (max_out - min_out)
    # Min-max normalize RGB to [0, 1]
    maxx, minn = np.max(image), np.min(image)
    image = (image - minn) / (maxx - minn)
    # To tensor (HWC -> CHW)
    data_transform = transforms.Compose([transforms.ToTensor()])
    image = data_transform(image).float()
    lr = data_transform(np.expand_dims(lr, 2)).float()
    # Add batch dimension and move to device
    lr = lr.unsqueeze(0).to(device)
    image = image.unsqueeze(0).to(device)
    return image, lr, min_out, max_out
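
# Shape sketch for the tensors above (assuming an H x W RGB input):
#   image: (1, 3, H, W) float32, min-max normalized to [0, 1]
#   lr:    (1, 1, H, W) float32, depth scaled by the fixed 0-5000 range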
# model inference
def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
    load_model(model_type)  # reset weights for the selected model
    image, lr, min_out, max_out = preprocess_inputs(rgb_image, lr_depth)
    # no_grad avoids building an autograd graph; without it, .numpy() below
    # would fail on a tensor that requires grad.
    with torch.no_grad():
        if model_type == "RGB-D-D":
            out = net(x_query=lr, rgb=image)
        elif model_type == "TOFDSR":
            out, _ = net(x_query=lr, rgb=image)
        else:
            raise ValueError(f"Unknown model_type: {model_type}")
    # Undo the fixed normalization to recover raw depth values
    pred = out[0, 0] * (max_out - min_out) + min_out
    pred = pred.cpu().numpy().astype(np.uint16)
    # raw 16-bit depth map
    pred_gray = Image.fromarray(pred)
    # heatmap visualization
    pred_norm = (pred - np.min(pred)) / (np.max(pred) - np.min(pred)) * 255
    pred_vis = pred_norm.astype(np.uint8)
    pred_heat = cv2.applyColorMap(pred_vis, cv2.COLORMAP_PLASMA)
    pred_heat = cv2.cvtColor(pred_heat, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR
    # return pred_gray, Image.fromarray(pred_heat)
    return Image.fromarray(pred_heat)
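
# Illustrative programmatic usage (output filename is hypothetical), mirroring
# what the demo's Run button does:
# rgb = Image.open("examples/RGB-D-D/20200518160957_RGB.jpg")
# depth = Image.open("examples/RGB-D-D/20200518160957_LR_fill_depth.png")
# infer(rgb, depth, "RGB-D-D").save("pred_heat.png")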
Intro = """ | |
## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution | |
[π Paper](https://arxiv.org/pdf/2410.11666) β’ [π» Code](https://github.com/yanzq95/DORNet) β’ [π¦ Model](https://huggingface.co/wzxwyx/DORNet/tree/main) | |
""" | |
with gr.Blocks(css=""" | |
.output-image { | |
display: flex; | |
justify-content: center; | |
align-items: center; | |
} | |
.output-image img { | |
margin: auto; | |
display: block; | |
} | |
""") as demo: | |
    gr.Markdown(Intro)
    with gr.Row():
        with gr.Column():
            rgb_input = gr.Image(label="RGB Image", type="pil")
            lr_input = gr.Image(label="Low-res Depth", type="pil", image_mode="I")
        with gr.Column():
            # output1 = gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"])
            output2 = gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])
    model_selector = gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
    run_button = gr.Button("Run Inference")
    gr.Examples(
        examples=[
            ["examples/RGB-D-D/20200518160957_RGB.jpg", "examples/RGB-D-D/20200518160957_LR_fill_depth.png", "RGB-D-D"],
            ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png", "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png", "TOFDSR"],
        ],
        inputs=[rgb_input, lr_input, model_selector],
        # outputs=[output1, output2],
        outputs=[output2],
        label="Try Examples ⬇",
    )
    # run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output1, output2])
    run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output2])

demo.launch()