# DORNet / app.py
# Last commit by RaynWu2002: "Update app.py" (09ef98b, verified)
import spaces
import subprocess
# def install_mmcv():
# try:
# subprocess.run([
# "pip", "install", "mmcv-full==1.7.2",
# "-f", "https://download.openmmlab.com/mmcv/dist/cu121/torch2.1.0/"
# ], check=True)
# except subprocess.CalledProcessError as e:
# print("Failed to install mmcv-full:", e)
# install_mmcv()
import mmcv
import gradio as gr
import numpy as np
import torch
import os
import cv2
from PIL import Image
import torchvision.transforms as transforms
from net.dornet import Net
from net.dornet_ddp import Net_ddp
# Print the runtime environment once at start-up; handy when diagnosing
# CUDA / torch / mmcv version mismatches in the Space logs.
print("=" * 50)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"torch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print(f"mmcv version: {mmcv.__version__}")
print("=" * 50)
# init
# Inference runs on the CPU. To try the GPU instead, use:
#   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"
# Default network; load_model() rebuilds it and loads checkpoint weights.
net = Net(tiny_model=False).to(device)
# Checkpoint file for each selectable model type (keys match the UI dropdown).
model_ckpt_map = {
    "RGB-D-D": "./checkpoints/RGBDD.pth",
    "TOFDSR": "./checkpoints/TOFDSR.pth",
}
# load model
@spaces.GPU
def load_model(model_type: str):
    """(Re)build the network for ``model_type`` and load its checkpoint.

    The new model is built and loaded into a local variable and only then
    published as the module-level ``net``. A failed checkpoint load therefore
    never leaves a randomly initialised model in ``net``.

    Args:
        model_type: A key of ``model_ckpt_map`` ("RGB-D-D" or "TOFDSR").

    Raises:
        ValueError: if ``model_type`` is not a known model type.
    """
    global net
    # Validate up front so unknown types raise the intended ValueError
    # rather than a KeyError from the checkpoint lookup.
    if model_type not in model_ckpt_map:
        raise ValueError(f"Unknown model_type: {model_type}")
    ckpt_path = model_ckpt_map[model_type]
    print(f"Loading weights from: {ckpt_path}")
    if model_type == "RGB-D-D":
        model = Net(tiny_model=False)
    elif model_type == "TOFDSR":
        # Only the ``srn`` sub-module of Net_ddp is used (and loaded) here.
        model = Net_ddp(tiny_model=False).srn
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    model = model.to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    net = model


load_model("RGB-D-D")
# data process
@spaces.GPU
def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
    """Turn the PIL inputs into normalised, batched tensors on ``device``.

    The depth map is bicubically resized to the RGB image's size and scaled
    by the fixed range [0, 5000]. The RGB image is min-max normalised to
    [0, 1] using its own extrema; a uniform image (zero range) becomes all
    zeros instead of NaN.

    Returns:
        ``(image, lr, min_out, max_out)``:
        - ``image`` is a (1, 3, H, W) float tensor.
        - ``lr`` is the batched depth tensor, (1, 1, H, W) for a
          single-channel depth image.
        - ``min_out``/``max_out`` are the depth range needed to undo the
          depth normalisation.
    """
    image = np.array(rgb_image.convert("RGB")).astype(np.float32)
    h, w, _ = image.shape
    # Upsample depth to the guidance resolution.
    # NOTE(review): assumes lr_depth is single-channel (the UI loads it with
    # image_mode="I"); a multi-channel depth image would break expand_dims below.
    lr = np.array(lr_depth.resize((w, h), Image.BICUBIC)).astype(np.float32)

    # Normalize depth with a fixed (not per-image) range so the caller can map
    # predictions back to raw depth units with the returned min/max.
    max_out, min_out = 5000.0, 0.0
    lr = (lr - min_out) / (max_out - min_out)

    # Normalize RGB per image; guard the zero-range case to avoid 0/0 -> NaN.
    maxx, minn = np.max(image), np.min(image)
    rgb_range = maxx - minn
    image = (image - minn) / rgb_range if rgb_range > 0 else np.zeros_like(image)

    # To tensor
    to_tensor = transforms.ToTensor()
    image = to_tensor(image).float()
    lr = to_tensor(np.expand_dims(lr, 2)).float()

    # Add batch dimension
    lr = lr.unsqueeze(0).to(device)
    image = image.unsqueeze(0).to(device)
    return image, lr, min_out, max_out
# model inference
@spaces.GPU
@torch.no_grad()
def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
    """Super-resolve ``lr_depth`` guided by ``rgb_image`` and return a heat map.

    Args:
        rgb_image: Guidance image; converted to RGB during preprocessing.
        lr_depth: Low-resolution depth map; resized to the RGB resolution
            during preprocessing.
        model_type: Model to run, "RGB-D-D" or "TOFDSR". Unknown values are
            rejected by ``load_model`` before any inference happens.

    Returns:
        A PIL RGB image: the predicted depth, min-max stretched to 0..255 and
        rendered with OpenCV's PLASMA colour map.

    Raises:
        gr.Error: if either input image is missing. The UI then shows a
            readable message instead of an AttributeError.
    """
    if rgb_image is None or lr_depth is None:
        raise gr.Error("Please provide both an RGB image and a low-res depth map.")

    load_model(model_type)  # reset weight
    image, lr, min_out, max_out = preprocess_inputs(rgb_image, lr_depth)

    if model_type == "RGB-D-D":
        out = net(x_query=lr, rgb=image)
    elif model_type == "TOFDSR":
        # This network returns a pair; only the first element is the depth.
        out, _ = net(x_query=lr, rgb=image)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    # Undo the depth normalisation applied in preprocess_inputs.
    pred = out[0, 0] * (max_out - min_out) + min_out
    # Clip before casting. NumPy's float->uint16 conversion of negative or
    # >65535 values is platform-dependent (typically wraps), which would turn
    # slightly negative depths into huge ones.
    pred = np.clip(pred.cpu().numpy(), 0, np.iinfo(np.uint16).max).astype(np.uint16)

    # Heat-map visualisation: stretch the prediction to 0..255. A constant
    # prediction has zero range, so render it as zeros instead of dividing by 0.
    lo, hi = int(pred.min()), int(pred.max())
    if hi > lo:
        pred_vis = ((pred - lo) / (hi - lo) * 255).astype(np.uint8)
    else:
        pred_vis = np.zeros_like(pred, dtype=np.uint8)
    pred_heat = cv2.applyColorMap(pred_vis, cv2.COLORMAP_PLASMA)
    # OpenCV produces BGR; PIL expects RGB.
    pred_heat = cv2.cvtColor(pred_heat, cv2.COLOR_BGR2RGB)
    return Image.fromarray(pred_heat)
# Markdown header shown at the top of the demo page.
Intro = """
## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution
[📄 Paper](https://arxiv.org/pdf/2410.11666) • [💻 Code](https://github.com/yanzq95/DORNet) • [📦 Model](https://huggingface.co/wzxwyx/DORNet/tree/main)
"""
# The CSS centers the output image inside its panel.
with gr.Blocks(css="""
.output-image {
display: flex;
justify-content: center;
align-items: center;
}
.output-image img {
margin: auto;
display: block;
}
""") as demo:
    gr.Markdown(Intro)
    with gr.Row():
        with gr.Column():
            rgb_input = gr.Image(label="RGB Image", type="pil")
            lr_input = gr.Image(label="Low-res Depth", type="pil", image_mode="I")
        with gr.Column():
            # infer() returns only the colour-mapped prediction.
            output2 = gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])
    model_selector = gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
    run_button = gr.Button("Run Inference")
    gr.Examples(
        examples=[
            ["examples/RGB-D-D/20200518160957_RGB.jpg", "examples/RGB-D-D/20200518160957_LR_fill_depth.png", "RGB-D-D"],
            ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png", "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png", "TOFDSR"],
        ],
        inputs=[rgb_input, lr_input, model_selector],
        outputs=[output2],
        label="Try Examples ↓",
    )
    run_button.click(fn=infer, inputs=[rgb_input, lr_input, model_selector], outputs=[output2])
demo.launch()