tristan-deep commited on
Commit
3cbb31e
·
1 Parent(s): 3dbb2a3

add gradio app

Browse files
Files changed (4) hide show
  1. .gitignore +2 -1
  2. app.py +201 -0
  3. configs/slider_params.yaml +29 -0
  4. requirements.txt +6 -0
.gitignore CHANGED
@@ -6,4 +6,5 @@ temp/
6
  *.pdf
7
  *.hash
8
  *.npz
9
- sweep_results/
 
 
6
  *.pdf
7
  *.hash
8
  *.npz
9
+ sweep_results/
10
+ .gradio
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import gradio as gr
5
+ import jax
6
+ import numpy as np
7
+ import spaces
8
+ import zea
9
+ from PIL import Image
10
+
11
+ from main import Config, init, run
12
+
13
+ CONFIG_PATH = "configs/semantic_dps.yaml"
14
+ SLIDER_CONFIG_PATH = "configs/slider_params.yaml"
15
+ ASSETS_DIR = "assets"
16
+
17
+ description = """
18
+ # Semantic Diffusion Posterior Sampling for Cardiac Ultrasound Dehazing
19
+ Select an example image below. The algorithm will dehaze the image. Note that the algorithm was heavily tuned for the DehazingEcho2025 challenge dataset, and not optimized for generalization. Therefore it is not expected to work well on any type of echocardiogram.
20
+
21
+ Two parameters that are interesting to control and adjust the amount of dehazing are the "Omega (Ventricle)" and "Eta (haze prior)"
22
+ """
23
+
24
+
25
+ @spaces.GPU
26
+ def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
27
+ if input_img is None:
28
+ raise gr.Error(
29
+ "No input image was provided. Please select or upload an image before running."
30
+ )
31
+
32
+ def _prepare_image(image):
33
+ resized = False
34
+
35
+ if image.mode != "L":
36
+ image = image.convert("L")
37
+
38
+ orig_shape = image.size[::-1]
39
+ h, w = diffusion_model.input_shape[:2]
40
+ if image.size != (w, h):
41
+ image = image.resize((w, h), Image.BILINEAR)
42
+ resized = True
43
+
44
+ image = np.array(image)
45
+
46
+ image = image.astype(np.float32)
47
+ image = image[None, ...]
48
+ return image, resized, orig_shape
49
+
50
+ try:
51
+ image, resized, orig_shape = _prepare_image(input_img)
52
+ except Exception:
53
+ raise gr.Error("Something went wrong with preparing the input image.")
54
+
55
+ guidance_kwargs = {
56
+ "omega": omega,
57
+ "omega_vent": omega_vent,
58
+ "omega_sept": omega_sept,
59
+ "eta": eta,
60
+ "smooth_l1_beta": params["guidance_kwargs"]["smooth_l1_beta"],
61
+ }
62
+
63
+ seed = jax.random.PRNGKey(config.seed)
64
+
65
+ try:
66
+ _, pred_tissue_images, *_ = run(
67
+ hazy_images=image,
68
+ diffusion_model=diffusion_model,
69
+ seed=seed,
70
+ guidance_kwargs=guidance_kwargs,
71
+ mask_params=params["mask_params"],
72
+ fixed_mask_params=params["fixed_mask_params"],
73
+ skeleton_params=params["skeleton_params"],
74
+ batch_size=1,
75
+ diffusion_steps=diffusion_steps,
76
+ initial_diffusion_step=params.get("initial_diffusion_step", 0),
77
+ threshold_output_quantile=params.get("threshold_output_quantile", None),
78
+ preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
79
+ bottom_transition_width=params.get("bottom_transition_width", 10.0),
80
+ verbose=False,
81
+ )
82
+ except Exception:
83
+ raise gr.Error("The algorithm failed to process the image.")
84
+
85
+ out_img = np.squeeze(pred_tissue_images[0])
86
+ out_img = np.clip(out_img, 0, 255).astype(np.uint8)
87
+ out_pil = Image.fromarray(out_img)
88
+ # Resize back to original input size if needed
89
+ if resized and out_pil.size != (orig_shape[1], orig_shape[0]):
90
+ out_pil = out_pil.resize((orig_shape[1], orig_shape[0]), Image.BILINEAR)
91
+ # Return tuple for ImageSlider: (input, output)
92
+ return (input_img, out_pil)
93
+
94
+
95
+ slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)
96
+
97
+ diffusion_steps_default = slider_params["diffusion_steps"]["default"]
98
+ diffusion_steps_min = slider_params["diffusion_steps"]["min"]
99
+ diffusion_steps_max = slider_params["diffusion_steps"]["max"]
100
+ diffusion_steps_step = slider_params["diffusion_steps"]["step"]
101
+
102
+ omega_default = slider_params["omega"]["default"]
103
+ omega_min = slider_params["omega"]["min"]
104
+ omega_max = slider_params["omega"]["max"]
105
+ omega_step = slider_params["omega"]["step"]
106
+
107
+ omega_vent_default = slider_params["omega_vent"]["default"]
108
+ omega_vent_min = slider_params["omega_vent"]["min"]
109
+ omega_vent_max = slider_params["omega_vent"]["max"]
110
+ omega_vent_step = slider_params["omega_vent"]["step"]
111
+
112
+ omega_sept_default = slider_params["omega_sept"]["default"]
113
+ omega_sept_min = slider_params["omega_sept"]["min"]
114
+ omega_sept_max = slider_params["omega_sept"]["max"]
115
+ omega_sept_step = slider_params["omega_sept"]["step"]
116
+
117
+ eta_default = slider_params["eta"]["default"]
118
+ eta_min = slider_params["eta"]["min"]
119
+ eta_max = slider_params["eta"]["max"]
120
+ eta_step = slider_params["eta"]["step"]
121
+
122
+
123
+ example_image_paths = [
124
+ os.path.join(ASSETS_DIR, f)
125
+ for f in os.listdir(ASSETS_DIR)
126
+ if f.lower().endswith(".png")
127
+ ]
128
+ example_images = [zea.io_lib.load_image(p) for p in example_image_paths]
129
+ examples = [[img] for img in example_images]
130
+
131
+
132
+ with gr.Blocks() as demo:
133
+ gr.Markdown(description)
134
+ status = gr.Markdown("Initializing model, please wait...", visible=True)
135
+ with gr.Row():
136
+ img1 = gr.Image(label="Input Image", type="pil", webcam_options=False)
137
+ img2 = gr.ImageSlider(label="Dehazed Image", type="pil")
138
+ gr.Examples(examples=examples, inputs=[img1])
139
+ with gr.Row():
140
+ diffusion_steps_slider = gr.Slider(
141
+ minimum=diffusion_steps_min,
142
+ maximum=diffusion_steps_max,
143
+ step=diffusion_steps_step,
144
+ value=diffusion_steps_default,
145
+ label="Diffusion Steps",
146
+ )
147
+ omega_slider = gr.Slider(
148
+ minimum=omega_min,
149
+ maximum=omega_max,
150
+ step=omega_step,
151
+ value=omega_default,
152
+ label="Omega (background)",
153
+ )
154
+ omega_vent_slider = gr.Slider(
155
+ minimum=omega_vent_min,
156
+ maximum=omega_vent_max,
157
+ step=omega_vent_step,
158
+ value=omega_vent_default,
159
+ label="Omega Ventricle",
160
+ )
161
+ omega_sept_slider = gr.Slider(
162
+ minimum=omega_sept_min,
163
+ maximum=omega_sept_max,
164
+ step=omega_sept_step,
165
+ value=omega_sept_default,
166
+ label="Omega Septum",
167
+ )
168
+ eta_slider = gr.Slider(
169
+ minimum=eta_min,
170
+ maximum=eta_max,
171
+ step=eta_step,
172
+ value=eta_default,
173
+ label="Eta (haze prior)",
174
+ )
175
+ run_btn = gr.Button("Run")
176
+
177
+ def initialize_model():
178
+ time.sleep(0.5) # Let UI update
179
+ config = Config.from_yaml(CONFIG_PATH)
180
+ diffusion_model = init(config)
181
+ params = config.params
182
+ return config, diffusion_model, params
183
+
184
+ config, diffusion_model, params = initialize_model()
185
+ status.visible = False
186
+
187
+ run_btn.click(
188
+ process_image,
189
+ inputs=[
190
+ img1,
191
+ diffusion_steps_slider,
192
+ omega_slider,
193
+ omega_vent_slider,
194
+ omega_sept_slider,
195
+ eta_slider,
196
+ ],
197
+ outputs=[img2],
198
+ )
199
+
200
+ if __name__ == "__main__":
201
+ demo.launch(share=True)
configs/slider_params.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusion_steps:
2
+ default: 300
3
+ min: 20
4
+ max: 500
5
+ step: 1
6
+
7
+ omega:
8
+ default: 1.0
9
+ min: 0.5
10
+ max: 3.0
11
+ step: 0.01
12
+
13
+ omega_vent:
14
+ default: 0.1
15
+ min: 0.0
16
+ max: 1.0
17
+ step: 0.01
18
+
19
+ omega_sept:
20
+ default: 2.0
21
+ min: 1.0
22
+ max: 5.0
23
+ step: 0.01
24
+
25
+ eta:
26
+ default: 0.00780
27
+ min: 0.0
28
+ max: 0.01
29
+ step: 0.00001
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ zea==0.0.4
2
+ jax[cuda12]
3
+ tyro
4
+ optuna
5
+ gradio
6
+ spaces