refactor: code and upgrade gradio
Browse files- app.py +16 -4
- functions.py +2 -9
- requirements.txt +34 -11
app.py
CHANGED
@@ -1,11 +1,23 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
title = "Water Body Segmentation - Image Segmentation PyTorch"
|
5 |
examples = ['examples/image1.png', 'examples/image2.png', 'examples/image3.png', 'examples/image4.png', 'examples/image5.png']
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
interface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import onnxruntime as rt
|
3 |
+
from functions import run_inference
|
4 |
+
|
5 |
+
model_path = 'weights/model.onnx'
|
6 |
+
session = rt.InferenceSession(model_path)
|
7 |
+
input_name = session.get_inputs()[0].name
|
8 |
+
output_name = session.get_outputs()[0].name
|
9 |
|
10 |
title = "Water Body Segmentation - Image Segmentation PyTorch"
|
11 |
examples = ['examples/image1.png', 'examples/image2.png', 'examples/image3.png', 'examples/image4.png', 'examples/image5.png']
|
12 |
|
13 |
+
def inference_wrapper(image):
|
14 |
+
return run_inference(image, session, input_name, output_name)
|
15 |
+
|
16 |
+
|
17 |
+
interface = gr.Interface(fn=inference_wrapper,
|
18 |
+
inputs=gr.Image(type='numpy', height=400, width=400),
|
19 |
+
outputs=gr.Image(type="numpy", height=400, width=400),
|
20 |
+
examples=examples,
|
21 |
+
title=title)
|
22 |
|
23 |
interface.launch()
|
functions.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import cv2
|
2 |
import numpy as np
|
3 |
-
import onnxruntime as rt
|
4 |
|
5 |
def resize_preserve_aspect_ratio(image, size):
|
6 |
h, w = image.shape[:2]
|
@@ -10,9 +9,7 @@ def resize_preserve_aspect_ratio(image, size):
|
|
10 |
image = cv2.resize(image, (size, size * h // w))
|
11 |
return image
|
12 |
|
13 |
-
def
|
14 |
-
|
15 |
-
model_path = 'weights/model.onnx'
|
16 |
inp_dim = inp_image.shape[:2]
|
17 |
|
18 |
image = cv2.resize(inp_image, (256, 256))
|
@@ -20,10 +17,6 @@ def predict(inp_image):
|
|
20 |
image = np.transpose(image, (2, 0, 1))
|
21 |
image = np.expand_dims(image, axis=0)
|
22 |
|
23 |
-
session = rt.InferenceSession(model_path)
|
24 |
-
input_name = session.get_inputs()[0].name
|
25 |
-
output_name = session.get_outputs()[0].name
|
26 |
-
|
27 |
pred_onx = session.run([output_name], {input_name: image.astype(np.float32)})[0]
|
28 |
pred_onx = pred_onx > 0.5
|
29 |
pred_onx = pred_onx * 255
|
@@ -32,5 +25,5 @@ def predict(inp_image):
|
|
32 |
pred_onx = np.expand_dims(pred_onx, axis=2)
|
33 |
pred_onx = np.concatenate((pred_onx, pred_onx, pred_onx), axis=2)
|
34 |
|
35 |
-
output = resize_preserve_aspect_ratio(pred_onx,
|
36 |
return output
|
|
|
1 |
import cv2
|
2 |
import numpy as np
|
|
|
3 |
|
4 |
def resize_preserve_aspect_ratio(image, size):
|
5 |
h, w = image.shape[:2]
|
|
|
9 |
image = cv2.resize(image, (size, size * h // w))
|
10 |
return image
|
11 |
|
12 |
+
def run_inference(inp_image, session, input_name, output_name):
|
|
|
|
|
13 |
inp_dim = inp_image.shape[:2]
|
14 |
|
15 |
image = cv2.resize(inp_image, (256, 256))
|
|
|
17 |
image = np.transpose(image, (2, 0, 1))
|
18 |
image = np.expand_dims(image, axis=0)
|
19 |
|
|
|
|
|
|
|
|
|
20 |
pred_onx = session.run([output_name], {input_name: image.astype(np.float32)})[0]
|
21 |
pred_onx = pred_onx > 0.5
|
22 |
pred_onx = pred_onx * 255
|
|
|
25 |
pred_onx = np.expand_dims(pred_onx, axis=2)
|
26 |
pred_onx = np.concatenate((pred_onx, pred_onx, pred_onx), axis=2)
|
27 |
|
28 |
+
output = resize_preserve_aspect_ratio(pred_onx, 400)
|
29 |
return output
|
requirements.txt
CHANGED
@@ -1,5 +1,8 @@
|
|
|
|
1 |
aiohttp==3.8.3
|
2 |
aiosignal==1.3.1
|
|
|
|
|
3 |
anyio==3.6.2
|
4 |
async-timeout==4.0.2
|
5 |
attrs==22.1.0
|
@@ -12,22 +15,29 @@ coloredlogs==15.0.1
|
|
12 |
contourpy==1.0.6
|
13 |
cryptography==38.0.3
|
14 |
cycler==0.11.0
|
15 |
-
fastapi==0.
|
16 |
ffmpy==0.3.0
|
|
|
17 |
flatbuffers==22.10.26
|
18 |
fonttools==4.38.0
|
19 |
frozenlist==1.3.3
|
20 |
-
fsspec==
|
21 |
-
gradio==
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
25 |
humanfriendly==10.0
|
26 |
idna==3.4
|
|
|
27 |
Jinja2==3.1.2
|
|
|
28 |
kiwisolver==1.4.4
|
29 |
linkify-it-py==1.0.3
|
30 |
-
markdown-it-py==
|
31 |
MarkupSafe==2.1.1
|
32 |
matplotlib==3.6.2
|
33 |
mdit-py-plugins==0.3.1
|
@@ -45,21 +55,34 @@ Pillow==9.3.0
|
|
45 |
protobuf==4.21.9
|
46 |
pycparser==2.21
|
47 |
pycryptodome==3.15.0
|
48 |
-
pydantic==
|
|
|
49 |
pydub==0.25.1
|
|
|
50 |
PyNaCl==1.5.0
|
51 |
pyparsing==3.0.9
|
|
|
52 |
python-dateutil==2.8.2
|
53 |
-
python-multipart==0.0.
|
54 |
pytz==2022.6
|
55 |
PyYAML==6.0
|
56 |
requests==2.28.1
|
57 |
rfc3986==1.5.0
|
|
|
|
|
|
|
|
|
|
|
58 |
six==1.16.0
|
59 |
sniffio==1.3.0
|
60 |
-
starlette==0.
|
61 |
sympy==1.11.1
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
63 |
uc-micro-py==1.0.1
|
64 |
urllib3==1.26.12
|
65 |
uvicorn==0.19.0
|
|
|
1 |
+
aiofiles==23.2.1
|
2 |
aiohttp==3.8.3
|
3 |
aiosignal==1.3.1
|
4 |
+
altair==5.3.0
|
5 |
+
annotated-types==0.7.0
|
6 |
anyio==3.6.2
|
7 |
async-timeout==4.0.2
|
8 |
attrs==22.1.0
|
|
|
15 |
contourpy==1.0.6
|
16 |
cryptography==38.0.3
|
17 |
cycler==0.11.0
|
18 |
+
fastapi==0.115.12
|
19 |
ffmpy==0.3.0
|
20 |
+
filelock==3.18.0
|
21 |
flatbuffers==22.10.26
|
22 |
fonttools==4.38.0
|
23 |
frozenlist==1.3.3
|
24 |
+
fsspec==2025.5.1
|
25 |
+
gradio==5.32.0
|
26 |
+
gradio_client==1.10.2
|
27 |
+
groovy==0.1.2
|
28 |
+
h11==0.16.0
|
29 |
+
hf-xet==1.1.2
|
30 |
+
httpcore==1.0.9
|
31 |
+
httpx==0.28.1
|
32 |
+
huggingface-hub==0.32.3
|
33 |
humanfriendly==10.0
|
34 |
idna==3.4
|
35 |
+
importlib_resources==6.5.2
|
36 |
Jinja2==3.1.2
|
37 |
+
jsonschema==4.17.3
|
38 |
kiwisolver==1.4.4
|
39 |
linkify-it-py==1.0.3
|
40 |
+
markdown-it-py==3.0.0
|
41 |
MarkupSafe==2.1.1
|
42 |
matplotlib==3.6.2
|
43 |
mdit-py-plugins==0.3.1
|
|
|
55 |
protobuf==4.21.9
|
56 |
pycparser==2.21
|
57 |
pycryptodome==3.15.0
|
58 |
+
pydantic==2.11.5
|
59 |
+
pydantic_core==2.33.2
|
60 |
pydub==0.25.1
|
61 |
+
Pygments==2.19.1
|
62 |
PyNaCl==1.5.0
|
63 |
pyparsing==3.0.9
|
64 |
+
pyrsistent==0.20.0
|
65 |
python-dateutil==2.8.2
|
66 |
+
python-multipart==0.0.20
|
67 |
pytz==2022.6
|
68 |
PyYAML==6.0
|
69 |
requests==2.28.1
|
70 |
rfc3986==1.5.0
|
71 |
+
rich==14.0.0
|
72 |
+
ruff==0.11.12
|
73 |
+
safehttpx==0.1.6
|
74 |
+
semantic-version==2.10.0
|
75 |
+
shellingham==1.5.4
|
76 |
six==1.16.0
|
77 |
sniffio==1.3.0
|
78 |
+
starlette==0.46.2
|
79 |
sympy==1.11.1
|
80 |
+
tomlkit==0.13.2
|
81 |
+
toolz==1.0.0
|
82 |
+
tqdm==4.67.1
|
83 |
+
typer==0.16.0
|
84 |
+
typing-inspection==0.4.1
|
85 |
+
typing_extensions==4.13.2
|
86 |
uc-micro-py==1.0.1
|
87 |
urllib3==1.26.12
|
88 |
uvicorn==0.19.0
|