gauthamk commited on
Commit
8c100dd
·
1 Parent(s): 00e18b7

refactor: code and upgrade gradio

Browse files
Files changed (3) hide show
  1. app.py +16 -4
  2. functions.py +2 -9
  3. requirements.txt +34 -11
app.py CHANGED
@@ -1,11 +1,23 @@
1
  import gradio as gr
2
- from functions import *
 
 
 
 
 
 
3
 
4
  title = "Water Body Segmentation - Image Segmentation PyTorch"
5
  examples = ['examples/image1.png', 'examples/image2.png', 'examples/image3.png', 'examples/image4.png', 'examples/image5.png']
6
 
7
- interface = gr.Interface(fn=predict, inputs=gr.Image(type= 'numpy').style(height= 256),
8
- outputs= gr.Image(type = "numpy").style(height= 256),
9
- examples= examples, title= title, css= '.gr-box {background-color: rgb(230 230 230);}')
 
 
 
 
 
 
10
 
11
  interface.launch()
 
1
  import gradio as gr
2
+ import onnxruntime as rt
3
+ from functions import run_inference
4
+
5
+ model_path = 'weights/model.onnx'
6
+ session = rt.InferenceSession(model_path)
7
+ input_name = session.get_inputs()[0].name
8
+ output_name = session.get_outputs()[0].name
9
 
10
  title = "Water Body Segmentation - Image Segmentation PyTorch"
11
  examples = ['examples/image1.png', 'examples/image2.png', 'examples/image3.png', 'examples/image4.png', 'examples/image5.png']
12
 
13
+ def inference_wrapper(image):
14
+ return run_inference(image, session, input_name, output_name)
15
+
16
+
17
+ interface = gr.Interface(fn=inference_wrapper,
18
+ inputs=gr.Image(type='numpy', height=400, width=400),
19
+ outputs=gr.Image(type="numpy", height=400, width=400),
20
+ examples=examples,
21
+ title=title)
22
 
23
  interface.launch()
functions.py CHANGED
@@ -1,6 +1,5 @@
1
  import cv2
2
  import numpy as np
3
- import onnxruntime as rt
4
 
5
  def resize_preserve_aspect_ratio(image, size):
6
  h, w = image.shape[:2]
@@ -10,9 +9,7 @@ def resize_preserve_aspect_ratio(image, size):
10
  image = cv2.resize(image, (size, size * h // w))
11
  return image
12
 
13
- def predict(inp_image):
14
-
15
- model_path = 'weights/model.onnx'
16
  inp_dim = inp_image.shape[:2]
17
 
18
  image = cv2.resize(inp_image, (256, 256))
@@ -20,10 +17,6 @@ def predict(inp_image):
20
  image = np.transpose(image, (2, 0, 1))
21
  image = np.expand_dims(image, axis=0)
22
 
23
- session = rt.InferenceSession(model_path)
24
- input_name = session.get_inputs()[0].name
25
- output_name = session.get_outputs()[0].name
26
-
27
  pred_onx = session.run([output_name], {input_name: image.astype(np.float32)})[0]
28
  pred_onx = pred_onx > 0.5
29
  pred_onx = pred_onx * 255
@@ -32,5 +25,5 @@ def predict(inp_image):
32
  pred_onx = np.expand_dims(pred_onx, axis=2)
33
  pred_onx = np.concatenate((pred_onx, pred_onx, pred_onx), axis=2)
34
 
35
- output = resize_preserve_aspect_ratio(pred_onx, 256)
36
  return output
 
1
  import cv2
2
  import numpy as np
 
3
 
4
  def resize_preserve_aspect_ratio(image, size):
5
  h, w = image.shape[:2]
 
9
  image = cv2.resize(image, (size, size * h // w))
10
  return image
11
 
12
+ def run_inference(inp_image, session, input_name, output_name):
 
 
13
  inp_dim = inp_image.shape[:2]
14
 
15
  image = cv2.resize(inp_image, (256, 256))
 
17
  image = np.transpose(image, (2, 0, 1))
18
  image = np.expand_dims(image, axis=0)
19
 
 
 
 
 
20
  pred_onx = session.run([output_name], {input_name: image.astype(np.float32)})[0]
21
  pred_onx = pred_onx > 0.5
22
  pred_onx = pred_onx * 255
 
25
  pred_onx = np.expand_dims(pred_onx, axis=2)
26
  pred_onx = np.concatenate((pred_onx, pred_onx, pred_onx), axis=2)
27
 
28
+ output = resize_preserve_aspect_ratio(pred_onx, 400)
29
  return output
requirements.txt CHANGED
@@ -1,5 +1,8 @@
 
1
  aiohttp==3.8.3
2
  aiosignal==1.3.1
 
 
3
  anyio==3.6.2
4
  async-timeout==4.0.2
5
  attrs==22.1.0
@@ -12,22 +15,29 @@ coloredlogs==15.0.1
12
  contourpy==1.0.6
13
  cryptography==38.0.3
14
  cycler==0.11.0
15
- fastapi==0.87.0
16
  ffmpy==0.3.0
 
17
  flatbuffers==22.10.26
18
  fonttools==4.38.0
19
  frozenlist==1.3.3
20
- fsspec==2022.11.0
21
- gradio==3.50.2
22
- h11==0.12.0
23
- httpcore==0.15.0
24
- httpx==0.23.0
 
 
 
 
25
  humanfriendly==10.0
26
  idna==3.4
 
27
  Jinja2==3.1.2
 
28
  kiwisolver==1.4.4
29
  linkify-it-py==1.0.3
30
- markdown-it-py==2.1.0
31
  MarkupSafe==2.1.1
32
  matplotlib==3.6.2
33
  mdit-py-plugins==0.3.1
@@ -45,21 +55,34 @@ Pillow==9.3.0
45
  protobuf==4.21.9
46
  pycparser==2.21
47
  pycryptodome==3.15.0
48
- pydantic==1.10.2
 
49
  pydub==0.25.1
 
50
  PyNaCl==1.5.0
51
  pyparsing==3.0.9
 
52
  python-dateutil==2.8.2
53
- python-multipart==0.0.5
54
  pytz==2022.6
55
  PyYAML==6.0
56
  requests==2.28.1
57
  rfc3986==1.5.0
 
 
 
 
 
58
  six==1.16.0
59
  sniffio==1.3.0
60
- starlette==0.21.0
61
  sympy==1.11.1
62
- typing_extensions==4.4.0
 
 
 
 
 
63
  uc-micro-py==1.0.1
64
  urllib3==1.26.12
65
  uvicorn==0.19.0
 
1
+ aiofiles==23.2.1
2
  aiohttp==3.8.3
3
  aiosignal==1.3.1
4
+ altair==5.3.0
5
+ annotated-types==0.7.0
6
  anyio==3.6.2
7
  async-timeout==4.0.2
8
  attrs==22.1.0
 
15
  contourpy==1.0.6
16
  cryptography==38.0.3
17
  cycler==0.11.0
18
+ fastapi==0.115.12
19
  ffmpy==0.3.0
20
+ filelock==3.18.0
21
  flatbuffers==22.10.26
22
  fonttools==4.38.0
23
  frozenlist==1.3.3
24
+ fsspec==2025.5.1
25
+ gradio==5.32.0
26
+ gradio_client==1.10.2
27
+ groovy==0.1.2
28
+ h11==0.16.0
29
+ hf-xet==1.1.2
30
+ httpcore==1.0.9
31
+ httpx==0.28.1
32
+ huggingface-hub==0.32.3
33
  humanfriendly==10.0
34
  idna==3.4
35
+ importlib_resources==6.5.2
36
  Jinja2==3.1.2
37
+ jsonschema==4.17.3
38
  kiwisolver==1.4.4
39
  linkify-it-py==1.0.3
40
+ markdown-it-py==3.0.0
41
  MarkupSafe==2.1.1
42
  matplotlib==3.6.2
43
  mdit-py-plugins==0.3.1
 
55
  protobuf==4.21.9
56
  pycparser==2.21
57
  pycryptodome==3.15.0
58
+ pydantic==2.11.5
59
+ pydantic_core==2.33.2
60
  pydub==0.25.1
61
+ Pygments==2.19.1
62
  PyNaCl==1.5.0
63
  pyparsing==3.0.9
64
+ pyrsistent==0.20.0
65
  python-dateutil==2.8.2
66
+ python-multipart==0.0.20
67
  pytz==2022.6
68
  PyYAML==6.0
69
  requests==2.28.1
70
  rfc3986==1.5.0
71
+ rich==14.0.0
72
+ ruff==0.11.12
73
+ safehttpx==0.1.6
74
+ semantic-version==2.10.0
75
+ shellingham==1.5.4
76
  six==1.16.0
77
  sniffio==1.3.0
78
+ starlette==0.46.2
79
  sympy==1.11.1
80
+ tomlkit==0.13.2
81
+ toolz==1.0.0
82
+ tqdm==4.67.1
83
+ typer==0.16.0
84
+ typing-inspection==0.4.1
85
+ typing_extensions==4.13.2
86
  uc-micro-py==1.0.1
87
  urllib3==1.26.12
88
  uvicorn==0.19.0