RaynWu2002 commited on
Commit
a0aa5cc
·
verified ·
1 Parent(s): 5f4c60f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -1,18 +1,18 @@
1
  import subprocess
2
- def install_mmcv():
3
- try:
4
- subprocess.run([
5
- "pip", "install", "mmcv-full==1.7.2",
6
- "-f", "https://download.openmmlab.com/mmcv/dist/cu121/torch2.1.0/"
7
- ], check=True)
8
- except subprocess.CalledProcessError as e:
9
- print("Failed to install mmcv-full:", e)
10
-
11
- install_mmcv()
12
  import mmcv
13
  import gradio as gr
14
  import numpy as np
15
- # import spaces
16
  import torch
17
  import os
18
  import cv2
@@ -40,7 +40,7 @@ model_ckpt_map = {
40
  }
41
 
42
  # load model
43
- # @spaces.GPU
44
  def load_model(model_type: str):
45
  global net
46
  ckpt_path = model_ckpt_map[model_type]
@@ -59,7 +59,7 @@ load_model("RGB-D-D")
59
 
60
 
61
  # data process
62
- # @spaces.GPU
63
  def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
64
  image = np.array(rgb_image.convert("RGB")).astype(np.float32)
65
  h, w, _ = image.shape
@@ -81,7 +81,7 @@ def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
81
 
82
 
83
  # model inference
84
- # @spaces.GPU
85
  @torch.no_grad()
86
  def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
87
  load_model(model_type) # reset weight
 
1
  import subprocess
2
+ # def install_mmcv():
3
+ # try:
4
+ # subprocess.run([
5
+ # "pip", "install", "mmcv-full==1.7.2",
6
+ # "-f", "https://download.openmmlab.com/mmcv/dist/cu121/torch2.1.0/"
7
+ # ], check=True)
8
+ # except subprocess.CalledProcessError as e:
9
+ # print("Failed to install mmcv-full:", e)
10
+
11
+ # install_mmcv()
12
  import mmcv
13
  import gradio as gr
14
  import numpy as np
15
+ import spaces
16
  import torch
17
  import os
18
  import cv2
 
40
  }
41
 
42
  # load model
43
+ @spaces.GPU
44
  def load_model(model_type: str):
45
  global net
46
  ckpt_path = model_ckpt_map[model_type]
 
59
 
60
 
61
  # data process
62
+ @spaces.GPU
63
  def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
64
  image = np.array(rgb_image.convert("RGB")).astype(np.float32)
65
  h, w, _ = image.shape
 
81
 
82
 
83
  # model inference
84
+ @spaces.GPU
85
  @torch.no_grad()
86
  def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
87
  load_model(model_type) # reset weight