tristan-deep committed
Commit 36e1539 · 0 Parent(s)
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .devcontainer
+ .env
+ temp/
+ *.png
+ *.pdf
Dockerfile ADDED
@@ -0,0 +1,9 @@
+ FROM zeahub/all:v0.0.4
+
+ RUN pip install --no-cache-dir SimpleITK tyro beautifulsoup4 tabulate optuna
+
+ RUN pip install --no-cache-dir --no-deps pytorch_fid
+
+ RUN pip install --no-cache-dir -U keras
+
+ WORKDIR /workspace
assets/patient-1-4C-frame-2.png ADDED
assets/patient-17-4C-frame-11.png ADDED
assets/patient-21-4C-frame-21.png ADDED
assets/patient-46-4C-frame-57.png ADDED
assets/patient-47-4C-frame-59.png ADDED
assets/patient-50-4C-frame-53.png ADDED
configs/semantic_dps.yaml ADDED
@@ -0,0 +1,28 @@
+ diffusion_model_path: "hf://tristan-deep/semantic-diffusion-echo-dehazing"
+ segmentation_model_path: "hf://tristan-deep/semantic-segmentation-echo-dehazing"
+ seed: 42
+
+ params:
+   diffusion_steps: 480
+   initial_diffusion_step: 0
+   batch_size: 16
+   threshold_output_quantile: 0.17447
+   preserve_bottom_percent: 32.0
+   bottom_transition_width: 7.0
+
+   mask_params:
+     sigma: 4.2
+     threshold: 0.176
+   fixed_mask_params:
+     top_px: 20
+     bottom_px: 40
+   skeleton_params:
+     sigma_pre: 4.2
+     sigma_post: 4.2
+     threshold: 0.176
+   guidance_kwargs:
+     omega: 1
+     omega_vent: 0.3
+     omega_sept: 2.037
+     eta: 0.00780
+     smooth_l1_beta: 1.6355
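
The nested blocks under params map one-to-one onto the keyword arguments of run() in main.py, so this file drives the whole pipeline. A minimal sketch of how the config is consumed, mirroring main.py (the asset folder is just the repository default; device setup is omitted):

import os

os.environ["KERAS_BACKEND"] = "jax"  # main.py assumes the JAX backend

from pathlib import Path

import jax
import zea
from keras import ops
from zea import Config

from main import init, run

config = Config.from_yaml("configs/semantic_dps.yaml")
seed = jax.random.PRNGKey(config.seed)

# Load the hazy input frames the same way main.py does
images = ops.stack(
    [zea.io_lib.load_image(p) for p in sorted(Path("./assets").glob("*.png"))], axis=0
)

diffusion_model = init(config)  # loads the diffusion and segmentation models
hazy, dehazed, haze, masks = run(
    images, diffusion_model=diffusion_model, seed=seed, **config.params
)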
eval.py ADDED
@@ -0,0 +1,321 @@
+ import warnings
+ from glob import glob
+ from pathlib import Path
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import tyro
+ from PIL import Image
+ from scipy.ndimage import binary_erosion, distance_transform_edt
+ from scipy.stats import ks_2samp
+ from zea.io_lib import load_image
+
+ import fid_score
+
+
+ def calculate_fid_score(denoised_image_dirs, ground_truth_dir):
+     if isinstance(denoised_image_dirs, (str, Path)):
+         denoised_image_dirs = [denoised_image_dirs]
+     elif not isinstance(denoised_image_dirs, list):
+         raise ValueError("Input must be a path or list of paths")
+
+     clean_images_folder = glob(str(ground_truth_dir) + "/*.png")
+
+     print(f"Looking for clean images in: {ground_truth_dir}")
+     print(f"Found {len(clean_images_folder)} clean images")
+
+     # Determine batch size based on the number of images
+     num_denoised = len(denoised_image_dirs)
+     num_clean = len(clean_images_folder)
+     optimal_batch_size = min(8, num_denoised, num_clean)
+     print(f"Using batch size: {optimal_batch_size}")
+
+     with warnings.catch_warnings():
+         warnings.filterwarnings("ignore", message="os.fork.*JAX is multithreaded")
+
+         fid_value = fid_score.calculate_fid_with_cached_ground_truth(
+             denoised_image_dirs,
+             clean_images_folder,
+             batch_size=optimal_batch_size,
+             device="cuda" if torch.cuda.is_available() else "cpu",
+             num_workers=2 if torch.cuda.is_available() else 0,
+             dims=2048,
+         )
+     return fid_value
+
+
+ def gcnr(img1, img2):
+     """Generalized Contrast-to-Noise Ratio"""
+     _, bins = np.histogram(np.concatenate((img1, img2)), bins=256)
+     f, _ = np.histogram(img1, bins=bins, density=True)
+     g, _ = np.histogram(img2, bins=bins, density=True)
+     f /= f.sum()
+     g /= g.sum()
+     return 1 - np.sum(np.minimum(f, g))
+
+
+ def cnr(img1, img2):
+     """Contrast-to-Noise Ratio"""
+     return (img1.mean() - img2.mean()) / np.sqrt(img1.var() + img2.var())
+
+
+ def calculate_cnr_gcnr(result_dehazed_cardiac_ultrasound, mask_path):
+     """
+     Evaluate gCNR and CNR metrics for a dehazed image using its paired ROI mask.
+     Returns a list containing one [CNR, gCNR] pair.
+     """
+     results = []
+
+     mask = np.array(Image.open(mask_path).convert("L"))
+
+     roi1_pixels = result_dehazed_cardiac_ultrasound[mask == 255]  # Foreground ROI
+     roi2_pixels = result_dehazed_cardiac_ultrasound[mask == 128]  # Background/noise ROI
+
+     gcnr_val = gcnr(roi1_pixels, roi2_pixels)
+     cnr_val = cnr(roi1_pixels, roi2_pixels)
+
+     results.append([cnr_val, gcnr_val])
+
+     return results
+
+
+ def calculate_ks_statistics(
+     result_hazy_cardiac_ultrasound, result_dehazed_cardiac_ultrasound, mask_path
+ ):
+     mask = np.array(Image.open(mask_path).convert("L"))
+
+     roi1_original = result_hazy_cardiac_ultrasound[mask == 255]  # region A
+     roi1_denoised = result_dehazed_cardiac_ultrasound[mask == 255]
+     roi2_original = result_hazy_cardiac_ultrasound[mask == 128]  # region B
+     roi2_denoised = result_dehazed_cardiac_ultrasound[mask == 128]
+
+     roi1_ks_stat, roi1_ks_p_value = (None, None)
+     roi2_ks_stat, roi2_ks_p_value = (None, None)
+
+     if roi1_original.size > 0 and roi1_denoised.size > 0:
+         roi1_ks_stat, roi1_ks_p_value = ks_2samp(roi1_original, roi1_denoised)
+
+     if roi2_original.size > 0 and roi2_denoised.size > 0:
+         roi2_ks_stat, roi2_ks_p_value = ks_2samp(roi2_original, roi2_denoised)
+
+     return roi1_ks_stat, roi1_ks_p_value, roi2_ks_stat, roi2_ks_p_value
+
+
+ def calculate_dice_asd(image_path, label_path, checkpoint_path, image_size=224):
+     try:
+         from test import inference  # Our segmentation method
+     except ImportError:
+         raise ImportError(
+             "Segmentation method not available; cannot compute Dice/ASD"
+         )
+
+     pred_img = inference(image_path, checkpoint_path, image_size)
+     pred = np.array(pred_img) > 127
+
+     label = Image.open(label_path).convert("L")
+     label = label.resize((image_size, image_size), Image.NEAREST)
+     label = np.array(label) > 127
+
+     # Calculate Dice
+     intersection = np.logical_and(pred, label).sum()
+     dice = 2 * intersection / (pred.sum() + label.sum() + 1e-8)
+
+     # Calculate ASD (average surface distance)
+     if pred.sum() == 0 or label.sum() == 0:
+         asd = np.nan
+     else:
+         pred_dt = distance_transform_edt(~pred)
+         label_dt = distance_transform_edt(~label)
+
+         surface_pred = pred ^ binary_erosion(pred)
+         surface_label = label ^ binary_erosion(label)
+
+         d1 = pred_dt[surface_label].mean()
+         d2 = label_dt[surface_pred].mean()
+         asd = (d1 + d2) / 2
+
+     return dice, asd
+
+
+ def calculate_final_score(aggregates):
+     try:
+         # Weighting: (FID + CNR + gCNR) : (KS^A + KS^B) : (Dice + ASD) = 5:3:2
+
+         group1_score = 0  # FID + CNR + gCNR
+         if aggregates.get("fid") is not None:
+             fid_min = 60.0
+             fid_max = 150.0
+             fid_score = (fid_max - aggregates["fid"]) / (fid_max - fid_min)
+             fid_score = max(0, min(1, fid_score))
+             group1_score += fid_score * 100 * 0.33
+
+         if aggregates.get("cnr_mean") is not None:
+             cnr_min = 1.0
+             cnr_max = 1.5
+             cnr_score = (aggregates["cnr_mean"] - cnr_min) / (cnr_max - cnr_min)
+             cnr_score = max(0, min(1, cnr_score))
+             group1_score += cnr_score * 100 * 0.33
+
+         if aggregates.get("gcnr_mean") is not None:
+             gcnr_min = 0.5
+             gcnr_max = 0.8
+             gcnr_score = (aggregates["gcnr_mean"] - gcnr_min) / (gcnr_max - gcnr_min)
+             gcnr_score = max(0, min(1, gcnr_score))
+             group1_score += gcnr_score * 100 * 0.34
+
+         group2_score = 0  # KS^A + KS^B
+         if aggregates.get("ks_roi1_ksstatistic_mean") is not None:
+             ks1_min = 0.1
+             ks1_max = 0.3
+             ks1_score = (ks1_max - aggregates["ks_roi1_ksstatistic_mean"]) / (
+                 ks1_max - ks1_min
+             )
+             ks1_score = max(0, min(1, ks1_score))
+             group2_score += ks1_score * 100 * 0.5
+
+         if aggregates.get("ks_roi2_ksstatistic_mean") is not None:
+             ks2_min = 0.0
+             ks2_max = 0.5
+             ks2_score = (aggregates["ks_roi2_ksstatistic_mean"] - ks2_min) / (
+                 ks2_max - ks2_min
+             )
+             ks2_score = max(0, min(1, ks2_score))
+             group2_score += ks2_score * 100 * 0.5
+
+         group3_score = 0  # Dice + ASD
+         if aggregates.get("dice_mean") is not None:
+             dice_min = 0.85
+             dice_max = 0.95
+             dice_score = (aggregates["dice_mean"] - dice_min) / (dice_max - dice_min)
+             dice_score = max(0, min(1, dice_score))
+             group3_score += dice_score * 100 * 0.5
+         if aggregates.get("asd_mean") is not None:
+             asd_min = 0.7
+             asd_max = 2.0
+             asd_score = (asd_max - aggregates["asd_mean"]) / (asd_max - asd_min)
+             asd_score = max(0, min(1, asd_score))
+             group3_score += asd_score * 100 * 0.5
+
+         # Final score calculation
+         final_score = (group1_score * 5 + group2_score * 3 + group3_score * 2) / 10
+
+         return final_score
+
+     except Exception as e:
+         print(f"Error calculating final score: {str(e)}")
+         return 0
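+
+ # Worked example of the 5:3:2 weighting above (illustrative numbers): with
+ # group1_score = 80, group2_score = 60 and group3_score = 50, the final score
+ # is (80 * 5 + 60 * 3 + 50 * 2) / 10 = (400 + 180 + 100) / 10 = 68.0.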
+
+
+ def plot_metrics(metrics, limits, out_path):
+     plt.style.use("seaborn-v0_8-darkgrid")
+     fig, axes = plt.subplots(1, len(metrics), figsize=(7.2, 2.7), dpi=600)
+     colors = ["#0057b7", "#ffb300", "#008744", "#d62d20"]
+     # Arrow direction: ↑ = higher is better, ↓ = lower is better
+     metric_labels = {
+         "CNR": r"CNR $\uparrow$",
+         "gCNR": r"gCNR $\uparrow$",
+         "KS_A": r"KS$_{septum}$ $\downarrow$",
+         "KS_B": r"KS$_{ventricle}$ $\uparrow$",
+     }
+     for idx, (ax, (name, values)) in enumerate(zip(axes, metrics.items())):
+         ax.hist(
+             values,
+             bins=30,
+             color=colors[idx % len(colors)],
+             alpha=0.85,
+             edgecolor="black",
+             linewidth=0.7,
+         )
+         ax.set_xlabel(metric_labels.get(name, name), fontsize=11)
+         ax.set_ylabel("Count", fontsize=10)
+         # Draw limits
+         if name in limits:
+             for lim in limits[name]:
+                 ax.axvline(lim, color="crimson", linestyle="--", lw=1.2)
+         ax.spines["top"].set_visible(False)
+         ax.spines["right"].set_visible(False)
+         ax.tick_params(axis="both", which="major", labelsize=9)
+     fig.tight_layout(pad=1.5)
+     fig.savefig(out_path, bbox_inches="tight", dpi=600)
+     plt.close(fig)
+
+
+ def main(folder: str, roi_folder: str, reference_folder: str):
+     folder = Path(folder)
+     roi_folder = Path(roi_folder)
+     reference_folder = Path(reference_folder)
+
+     folder_files = set(f.name for f in folder.glob("*.png"))
+     roi_files = set(f.name for f in roi_folder.glob("*.png"))
+     ref_files = set(f.name for f in reference_folder.glob("*.png"))
+
+     print(f"Found {len(folder_files)} .png files in output folder: {folder}")
+     print(f"Found {len(roi_files)} .png files in ROI folder: {roi_folder}")
+     print(f"Found {len(ref_files)} .png files in reference folder: {reference_folder}")
+
+     # Find the intersection of filenames
+     common_files = sorted(folder_files & roi_files & ref_files)
+     print(f"Found {len(common_files)} images present in all folders.")
+     if len(common_files) == 0:
+         print("No matching images found in all folders. Check your folder contents.")
+         print(f"Output folder files: {sorted(folder_files)}")
+         print(f"ROI folder files: {sorted(roi_files)}")
+         print(f"Reference folder files: {sorted(ref_files)}")
+     assert len(common_files) > 0, (
+         "No matching .png files in all folders. Cannot proceed."
+     )
+
+     metrics = {"CNR": [], "gCNR": [], "KS_A": [], "KS_B": []}
+     limits = {
+         "CNR": [1.0, 1.5],
+         "gCNR": [0.5, 0.8],
+         "KS_A": [0.1, 0.3],
+         "KS_B": [0.0, 0.5],
+     }
+
+     for name in common_files:
+         our_path = folder / name
+         roi_path = roi_folder / name
+         ref_path = reference_folder / name
+
+         assert our_path.exists(), f"Missing file in output folder: {our_path}"
+         assert roi_path.exists(), f"Missing file in ROI folder: {roi_path}"
+         assert ref_path.exists(), f"Missing file in reference folder: {ref_path}"
+
+         try:
+             img = np.array(load_image(str(our_path)))
+             img_ref = np.array(load_image(str(ref_path)))
+         except Exception as e:
+             print(f"Error loading image {name}: {e}")
+             continue
+
+         # CNR/gCNR
+         cnr_gcnr = calculate_cnr_gcnr(img, str(roi_path))
+         metrics["CNR"].append(cnr_gcnr[0][0])
+         metrics["gCNR"].append(cnr_gcnr[0][1])
+
+         # KS statistics
+         ks_a, _, ks_b, _ = calculate_ks_statistics(img_ref, img, str(roi_path))
+         metrics["KS_A"].append(ks_a)
+         metrics["KS_B"].append(ks_b)
+
+     # Compute statistics
+     stats = {
+         k: (np.mean(v), np.std(v), np.min(v), np.max(v)) for k, v in metrics.items()
+     }
+     print("Contrast statistics:")
+     for k, (mean, std, minv, maxv) in stats.items():
+         print(f"{k}: mean={mean:.3f}, std={std:.3f}, min={minv:.3f}, max={maxv:.3f}")
+
+     plot_metrics(metrics, limits, str(folder / "contrast_metrics.png"))
+     print(f"Saved metrics plot to {folder / 'contrast_metrics.png'}")
+
+     # Compute FID (fid_value, so the imported fid_score module is not shadowed)
+     fid_image_paths = [str(folder / name) for name in common_files]
+     fid_value = calculate_fid_score(fid_image_paths, str(reference_folder))
+     print(f"FID between {folder} and {reference_folder}: {fid_value:.3f}")
+
+
+ if __name__ == "__main__":
+     tyro.cli(main)
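
gcnr() measures histogram overlap between two ROIs (1 means fully separable pixel distributions, 0 means identical), while cnr() is the classic mean-difference-over-noise ratio. A small self-contained sanity check with synthetic ROIs (the numbers are illustrative, not from the challenge data):

import numpy as np

from eval import cnr, gcnr

rng = np.random.default_rng(0)
tissue = rng.normal(180, 15, size=5000)  # bright foreground ROI (mask == 255)
cavity = rng.normal(60, 15, size=5000)   # dark background ROI (mask == 128)

print(f"CNR:  {cnr(tissue, cavity):.2f}")   # ~5.7: (180 - 60) / sqrt(15**2 + 15**2)
print(f"gCNR: {gcnr(tissue, cavity):.2f}")  # ~1.00: essentially no histogram overlap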
fid_score.py ADDED
@@ -0,0 +1,480 @@
+ """Calculates the Frechet Inception Distance (FID) to evaluate GANs
+
+ The FID metric calculates the distance between two distributions of images.
+ Typically, we have summary statistics (mean & covariance matrix) of one
+ of these distributions, while the 2nd distribution is given by a GAN.
+
+ When run as a stand-alone program, it compares the distribution of
+ images that are stored as PNG/JPEG at a specified location with a
+ distribution given by summary statistics (in pickle format).
+
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
+ the pool_3 layer of the inception net for generated samples and real world
+ samples respectively.
+
+ See --help to see further details.
+
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
+ of Tensorflow
+
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import hashlib
+ import os
+ import pathlib
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+ import numpy as np
+ import torch
+ import torchvision.transforms as TF
+ from PIL import Image
+ from scipy import linalg
+ from torch.nn.functional import adaptive_avg_pool2d
+
+ try:
+     from tqdm import tqdm
+ except ImportError:
+     # If tqdm is not available, provide a mock version of it
+     def tqdm(x):
+         return x
+
+
+ from pytorch_fid.inception import InceptionV3
+
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+ parser.add_argument("--batch-size", type=int, default=50, help="Batch size to use")
+ parser.add_argument(
+     "--num-workers",
+     type=int,
+     help=(
+         "Number of processes to use for data loading. Defaults to `min(8, num_cpus)`"
+     ),
+ )
+ parser.add_argument(
+     "--device", type=str, default=None, help="Device to use. Like cuda, cuda:0 or cpu"
+ )
+ parser.add_argument(
+     "--dims",
+     type=int,
+     default=2048,
+     choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
+     help=(
+         "Dimensionality of Inception features to use. By default, uses pool3 features"
+     ),
+ )
+ parser.add_argument(
+     "--save-stats",
+     action="store_true",
+     help=(
+         "Generate an npz archive from a directory of samples. "
+         "The first path is used as input and the second as output."
+     ),
+ )
+ parser.add_argument(
+     "path",
+     type=str,
+     nargs=2,
+     help=("Paths to the generated images or to .npz statistic files"),
+ )
+
+ IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
+
+
+ class ImagePathDataset(torch.utils.data.Dataset):
+     def __init__(self, files, transforms=None):
+         self.files = files
+         self.transforms = transforms
+
+     def __len__(self):
+         return len(self.files)
+
+     def __getitem__(self, i):
+         path = self.files[i]
+         img = Image.open(path).convert("RGB")
+         if self.transforms is not None:
+             img = self.transforms(img)
+         return img
+
+
+ def get_activations(
+     files, model, batch_size=50, dims=2048, device="cpu", num_workers=1
+ ):
+     """Calculates the activations of the pool_3 layer for all images.
+
+     Params:
+     -- files       : List of image files paths
+     -- model       : Instance of inception model
+     -- batch_size  : Batch size of images for the model to process at once.
+                      Make sure that the number of samples is a multiple of
+                      the batch size, otherwise some samples are ignored. This
+                      behavior is retained to match the original FID score
+                      implementation.
+     -- dims        : Dimensionality of features returned by Inception
+     -- device      : Device to run calculations
+     -- num_workers : Number of parallel dataloader workers
+
+     Returns:
+     -- A numpy array of dimension (num images, dims) that contains the
+        activations of the given tensor when feeding inception with the
+        query tensor.
+     """
+     model.eval()
+
+     if batch_size > len(files):
+         print(
+             (
+                 "Warning: batch size is bigger than the data size. "
+                 "Setting batch size to data size"
+             )
+         )
+         batch_size = len(files)
+
+     dataset = ImagePathDataset(files, transforms=TF.ToTensor())
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         drop_last=False,
+         num_workers=num_workers,
+     )
+
+     pred_arr = np.empty((len(files), dims))
+
+     start_idx = 0
+
+     for batch in tqdm(dataloader):
+         batch = batch.to(device)
+
+         with torch.no_grad():
+             pred = model(batch)[0]
+
+         # If model output is not scalar, apply global spatial average pooling.
+         # This happens if you choose a dimensionality not equal 2048.
+         if pred.size(2) != 1 or pred.size(3) != 1:
+             pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+         pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+         pred_arr[start_idx : start_idx + pred.shape[0]] = pred
+
+         start_idx = start_idx + pred.shape[0]
+
+     return pred_arr
+
+
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+     """Numpy implementation of the Frechet Distance.
+     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+     and X_2 ~ N(mu_2, C_2) is
+         d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+     Stable version by Dougal J. Sutherland.
+
+     Params:
+     -- mu1    : Numpy array containing the activations of a layer of the
+                 inception net (like returned by the function 'get_predictions')
+                 for generated samples.
+     -- mu2    : The sample mean over activations, precalculated on a
+                 representative data set.
+     -- sigma1 : The covariance matrix over activations for generated samples.
+     -- sigma2 : The covariance matrix over activations, precalculated on a
+                 representative data set.
+
+     Returns:
+     --        : The Frechet Distance.
+     """
+
+     mu1 = np.atleast_1d(mu1)
+     mu2 = np.atleast_1d(mu2)
+
+     sigma1 = np.atleast_2d(sigma1)
+     sigma2 = np.atleast_2d(sigma2)
+
+     assert mu1.shape == mu2.shape, (
+         "Training and test mean vectors have different lengths"
+     )
+     assert sigma1.shape == sigma2.shape, (
+         "Training and test covariances have different dimensions"
+     )
+
+     diff = mu1 - mu2
+
+     # Product might be almost singular
+     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+     if not np.isfinite(covmean).all():
+         msg = (
+             "fid calculation produces singular product; "
+             "adding %s to diagonal of cov estimates"
+         ) % eps
+         print(msg)
+         offset = np.eye(sigma1.shape[0]) * eps
+         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+     # Numerical error might give slight imaginary component
+     if np.iscomplexobj(covmean):
+         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+             m = np.max(np.abs(covmean.imag))
+             raise ValueError("Imaginary component {}".format(m))
+         covmean = covmean.real
+
+     tr_covmean = np.trace(covmean)
+
+     return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+ def calculate_activation_statistics(
+     files, model, batch_size=50, dims=2048, device="cpu", num_workers=1
+ ):
+     """Calculation of the statistics used by the FID.
+     Params:
+     -- files       : List of image files paths
+     -- model       : Instance of inception model
+     -- batch_size  : The images numpy array is split into batches with
+                      batch size batch_size. A reasonable batch size
+                      depends on the hardware.
+     -- dims        : Dimensionality of features returned by Inception
+     -- device      : Device to run calculations
+     -- num_workers : Number of parallel dataloader workers
+
+     Returns:
+     -- mu    : The mean over samples of the activations of the pool_3 layer of
+                the inception model.
+     -- sigma : The covariance matrix of the activations of the pool_3 layer of
+                the inception model.
+     """
+     act = get_activations(files, model, batch_size, dims, device, num_workers)
+     mu = np.mean(act, axis=0)
+     sigma = np.cov(act, rowvar=False)
+     return mu, sigma
+
+
+ def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=1):
+     # Unlike the upstream pytorch_fid version, this variant always receives a
+     # list of image files, so the .npz branch is disabled:
+     # if path.endswith('.npz'):
+     #     with np.load(path) as f:
+     #         m, s = f['mu'][:], f['sigma'][:]
+     m, s = calculate_activation_statistics(
+         path, model, batch_size, dims, device, num_workers
+     )
+     return m, s
+
+
+ def _fid_cache_paths():
+     tmp_dir = pathlib.Path("tmp")
+     tmp_dir.mkdir(exist_ok=True)
+     stats_path = tmp_dir / "fid_stats.npz"
+     hash_path = tmp_dir / "fid_stats.hash"
+     return stats_path, hash_path
+
+
+ def _load_fid_stats(stats_path):
+     arr = np.load(stats_path)
+     return arr["mu"], arr["sigma"]
+
+
+ def _save_fid_stats(stats_path, mu, sigma):
+     np.savez_compressed(stats_path, mu=mu, sigma=sigma)
+
+
+ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
+     """Calculates the FID of two paths, caching the ground truth stats if the
+     second path is a directory of images."""
+     if isinstance(paths[1], (str, pathlib.Path)) and pathlib.Path(paths[1]).is_dir():
+         # Get all PNGs in the directory
+         gt_images = list(pathlib.Path(paths[1]).glob("*.png"))
+         stats_path, hash_path = _fid_cache_paths()
+         if stats_path.exists():
+             print(f"Using cached FID stats from {stats_path}")
+             print("WARNING: Cache may be stale if ground truth images have changed.")
+             m1, s1 = _load_fid_stats(stats_path)
+         else:
+             print("Computing FID stats for ground truth images...")
+             block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+             model = InceptionV3([block_idx]).to(device)
+             m1, s1 = calculate_activation_statistics(
+                 gt_images, model, batch_size, dims, device, num_workers
+             )
+             _save_fid_stats(stats_path, m1, s1)
+         # m2, s2 for denoised images
+         block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+         model = InceptionV3([block_idx]).to(device)
+         m2, s2 = calculate_activation_statistics(
+             paths[0], model, batch_size, dims, device, num_workers
+         )
+         fid_value = calculate_frechet_distance(m1, s1, m2, s2)
+         return fid_value
+
+     block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+     model = InceptionV3([block_idx]).to(device)
+     m1, s1 = compute_statistics_of_path(
+         paths[0], model, batch_size, dims, device, num_workers
+     )
+     m2, s2 = compute_statistics_of_path(
+         paths[1], model, batch_size, dims, device, num_workers
+     )
+     fid_value = calculate_frechet_distance(m1, s1, m2, s2)
+     return fid_value
+
+
+ def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
+     """Computes activation statistics for paths[0] and saves them to the .npz
+     file given by paths[1]."""
+     if not os.path.exists(paths[0]):
+         raise RuntimeError("Invalid path: %s" % paths[0])
+
+     if os.path.exists(paths[1]):
+         raise RuntimeError("Existing output file: %s" % paths[1])
+
+     block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+     model = InceptionV3([block_idx]).to(device)
+
+     print(f"Saving statistics for {paths[0]}")
+
+     m1, s1 = compute_statistics_of_path(
+         paths[0], model, batch_size, dims, device, num_workers
+     )
+
+     np.savez_compressed(paths[1], mu=m1, sigma=s1)
+
+
+ def calculate_fid_with_cached_ground_truth(
+     denoised_image_dirs,
+     ground_truth_image_dirs,
+     batch_size=50,
+     device="cpu",
+     dims=2048,
+     num_workers=1,
+ ):
+     """
+     Calculates the FID between denoised images and ground truth images, using
+     cached stats for the ground truth if possible.
+     Args:
+         denoised_image_dirs: list of denoised image paths
+         ground_truth_image_dirs: list of ground truth image paths (or a directory)
+         batch_size, device, dims, num_workers: same as calculate_fid_given_paths
+     Returns:
+         FID value
+     """
+     # If ground_truth_image_dirs is a directory, get all PNGs
+     if isinstance(ground_truth_image_dirs, (str, pathlib.Path)):
+         ground_truth_image_dirs = list(
+             pathlib.Path(ground_truth_image_dirs).glob("*.png")
+         )
+
+     # Hash the ground truth file names and mtimes to detect a stale cache
+     def compute_file_hashes(file_list):
+         hash_md5 = hashlib.md5()
+         for fname in sorted(map(str, file_list)):
+             try:
+                 stat = os.stat(fname)
+                 hash_md5.update(fname.encode())
+                 hash_md5.update(str(stat.st_mtime).encode())
+             except Exception:
+                 continue
+         return hash_md5.hexdigest()
+
+     stats_path, hash_path = _fid_cache_paths()
+     # TODO: caching shouldn't be based on ground truth image dirs
+     # since we can have multiple reconstructions of same ground truth
+     current_hash = compute_file_hashes(ground_truth_image_dirs)
+     cache_valid = False
+     if stats_path.exists() and hash_path.exists():
+         try:
+             with open(hash_path, "r") as f:
+                 cached_hash = f.read().strip()
+             if cached_hash == current_hash:
+                 cache_valid = True
+         except Exception:
+             pass
+     # TODO: need more sophisticated caching for sweeps
+     if cache_valid:
+         print(f"Using cached FID stats from {stats_path}")
+         mu, sigma = _load_fid_stats(stats_path)
+     else:
+         print("Computing FID stats for ground truth images...")
+         block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+         model = InceptionV3([block_idx]).to(device)
+         mu, sigma = calculate_activation_statistics(
+             ground_truth_image_dirs,
+             model,
+             batch_size=batch_size,
+             dims=dims,
+             device=device,
+             num_workers=num_workers,
+         )
+         _save_fid_stats(stats_path, mu, sigma)
+         with open(hash_path, "w") as f:
+             f.write(current_hash)
+     # Compute stats for the denoised images
+     block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+     model = InceptionV3([block_idx]).to(device)
+     mu2, sigma2 = calculate_activation_statistics(
+         denoised_image_dirs,
+         model,
+         batch_size=batch_size,
+         dims=dims,
+         device=device,
+         num_workers=num_workers,
+     )
+     fid_value = calculate_frechet_distance(mu, sigma, mu2, sigma2)
+     return fid_value
+
+
+ def main():
+     args = parser.parse_args()
+
+     if args.device is None:
+         device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
+     else:
+         device = torch.device(args.device)
+
+     if args.num_workers is None:
+         try:
+             num_cpus = len(os.sched_getaffinity(0))
+         except AttributeError:
+             # os.sched_getaffinity is not available under Windows; use
+             # os.cpu_count instead (which may not return the *available*
+             # number of CPUs).
+             num_cpus = os.cpu_count()
+
+         num_workers = min(num_cpus, 8) if num_cpus is not None else 0
+     else:
+         num_workers = args.num_workers
+
+     if args.save_stats:
+         save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
+         return
+
+     fid_value = calculate_fid_given_paths(
+         args.path, args.batch_size, device, args.dims, num_workers
+     )
+     print("FID: ", fid_value)
+
+
+ if __name__ == "__main__":
+     main()
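
calculate_frechet_distance implements d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). When the two covariances are equal, the trace term vanishes, which gives an easy closed-form sanity check (a minimal sketch; the 3-dimensional stats are synthetic):

import numpy as np

from fid_score import calculate_frechet_distance

mu1, sigma1 = np.zeros(3), np.eye(3)
mu2, sigma2 = np.full(3, 2.0), np.eye(3)

# Equal covariances: the distance reduces to ||mu1 - mu2||^2 = 3 * 2.0**2 = 12
print(calculate_frechet_distance(mu1, sigma1, mu2, sigma2))  # ~12.0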
main.py ADDED
@@ -0,0 +1,743 @@
+ import copy
+ import os
+ from pathlib import Path
+
+ os.environ["KERAS_BACKEND"] = "jax"
+
+ import jax
+ import keras
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import scipy
+ import tyro
+ import zea
+ from keras import ops
+ from matplotlib.patches import PathPatch
+ from matplotlib.path import Path as pltPath
+ from PIL import Image
+ from skimage import filters, measure, morphology
+ from zea import Config, init_device, log
+ from zea.internal.operators import Operator
+ from zea.models.diffusion import (
+     DPS,
+     DiffusionModel,
+     diffusion_guidance_registry,
+ )
+ from zea.tensor_ops import L2
+ from zea.utils import translate
+ from zea.visualize import plot_image_grid
+
+
+ def L1(x):
+     """L1 norm of a tensor.
+
+     Implementation of the L1 norm: https://mathworld.wolfram.com/L1-Norm.html
+     """
+     return ops.sum(ops.abs(x))
+
+
+ def smooth_L1(x, beta=0.4):
+     """Smooth L1 loss function.
+
+     Quadratic for |x| < beta and linear beyond, so small beta values make it
+     behave like L1 loss, while large beta values make it behave like L2 loss.
+     """
+     abs_x = ops.abs(x)
+     loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
+     return ops.sum(loss)
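+
+ # Sanity check of the beta behavior above (illustrative numbers): with
+ # beta = 0.4, a residual of 0.2 falls in the quadratic region and contributes
+ # 0.5 * 0.2**2 / 0.4 = 0.05, while a residual of 1.0 is in the linear region
+ # and contributes 1.0 - 0.5 * 0.4 = 0.8.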
+
+
+ def postprocess(data, normalization_range):
+     """Postprocess data from model output to a uint8 image."""
+     data = ops.clip(data, *normalization_range)
+     data = translate(data, normalization_range, (0, 255))
+     data = ops.convert_to_numpy(data)
+     data = np.squeeze(data, axis=-1)
+     return np.clip(data, 0, 255).astype("uint8")
+
+
+ def preprocess(data, normalization_range):
+     """Preprocess data for model input: uint8 image(s) in [0, 255] to the model input range."""
+     data = ops.convert_to_tensor(data, dtype="float32")
+     data = translate(data, (0, 255), normalization_range)
+     data = ops.expand_dims(data, axis=-1)
+     return data
+
+
+ def apply_bottom_preservation(
+     output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
+ ):
+     """Apply bottom preservation with a smooth windowed transition.
+
+     Args:
+         output_images: Model output images, (batch, height, width, channels)
+         input_images: Original input images, (batch, height, width, channels)
+         preserve_bottom_percent: Percentage of the bottom to preserve from the input (default 30%)
+         transition_width: Percentage of image height used for the smooth transition (default 10%)
+
+     Returns:
+         Blended images with the bottom portion preserved
+     """
+     batch_size, height, width, channels = ops.shape(output_images)
+
+     preserve_height = int(height * preserve_bottom_percent / 100.0)
+     transition_height = int(height * transition_width / 100.0)
+
+     transition_start = height - preserve_height - transition_height
+     preserve_start = height - preserve_height
+
+     transition_start = max(0, transition_start)
+     preserve_start = min(height, preserve_start)
+
+     if transition_start >= preserve_start:
+         transition_start = preserve_start
+         transition_height = 0
+
+     y_coords = ops.arange(height, dtype="float32")
+     y_coords = ops.reshape(y_coords, (height, 1, 1))
+
+     if transition_height > 0:
+         # Smooth transition using cosine interpolation
+         transition_region = ops.logical_and(
+             y_coords >= transition_start, y_coords < preserve_start
+         )
+
+         transition_progress = (y_coords - transition_start) / transition_height
+         transition_progress = ops.clip(transition_progress, 0.0, 1.0)
+
+         # Use cosine for a smooth transition: 0.5 * (1 - cos(pi * t))
+         cosine_weight = 0.5 * (1.0 - ops.cos(np.pi * transition_progress))
+
+         blend_weight = ops.where(
+             y_coords < transition_start,
+             0.0,
+             ops.where(
+                 transition_region,
+                 cosine_weight,
+                 1.0,
+             ),
+         )
+     else:
+         # No transition, just a hard switch
+         blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)
+
+     blend_weight = ops.expand_dims(blend_weight, axis=0)
+
+     blended_images = (1.0 - blend_weight) * output_images + blend_weight * input_images
+
+     return blended_images
+
+
+ def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
+     """Extract skeletons from the input images."""
+     images_np = ops.convert_to_numpy(images)
+     images_np = np.clip(images_np, input_range[0], input_range[1])
+     images_np = translate(images_np, input_range, (0, 1))
+     images_np = np.squeeze(images_np, axis=-1)
+
+     skeleton_masks = []
+     for img in images_np:
+         img[img < threshold] = 0
+         smoothed = filters.gaussian(img, sigma=sigma_pre)
+         binary = smoothed > filters.threshold_otsu(smoothed)
+         skeleton = morphology.skeletonize(binary)
+         skeleton = morphology.dilation(skeleton, morphology.disk(2))
+         skeleton = filters.gaussian(skeleton.astype(np.float32), sigma=sigma_post)
+         skeleton_masks.append(skeleton)
+
+     skeleton_masks = np.array(skeleton_masks)
+     skeleton_masks = np.expand_dims(skeleton_masks, axis=-1)
+
+     # Normalize to [0, 1]
+     min_val, max_val = np.min(skeleton_masks), np.max(skeleton_masks)
+     skeleton_masks = (skeleton_masks - min_val) / (max_val - min_val + 1e-8)
+
+     return ops.convert_to_tensor(skeleton_masks, dtype=images.dtype)
+
+
+ class IdentityOperator(Operator):
+     def forward(self, data):
+         return data
+
+     def __str__(self):
+         return "y = x"
+
+
+ @diffusion_guidance_registry(name="semantic_dps")
+ class SemanticDPS(DPS):
+     def __init__(
+         self,
+         diffusion_model,
+         segmentation_model,
+         operator,
+         disable_jit=False,
+         **kwargs,
+     ):
+         """Initialize the diffusion guidance.
+
+         Args:
+             diffusion_model: The diffusion model to use for guidance.
+             segmentation_model: The semantic segmentation model used to build masks.
+             operator: The forward (measurement) operator to use for guidance.
+             disable_jit: Whether to disable JIT compilation.
+         """
+         self.diffusion_model = diffusion_model
+         self.segmentation_model = segmentation_model
+         self.operator = operator
+         self.disable_jit = disable_jit
+         self.setup(**kwargs)
+
+     def _get_fixed_mask(
+         self,
+         images,
+         bottom_px=40,
+         top_px=20,
+     ):
+         batch_size, height, width, channels = ops.shape(images)
+
+         # Create row indices for each pixel
+         row_indices = ops.arange(height)
+         row_indices = ops.reshape(row_indices, (height, 1))
+         row_indices = ops.tile(row_indices, (1, width))
+
+         # Mask out the top and bottom rows
+         fixed_mask = ops.where(
+             ops.logical_or(row_indices < top_px, row_indices >= height - bottom_px),
+             1.0,
+             0.0,
+         )
+         fixed_mask = ops.expand_dims(fixed_mask, axis=0)
+         fixed_mask = ops.expand_dims(fixed_mask, axis=-1)
+         fixed_mask = ops.tile(fixed_mask, (batch_size, 1, 1, channels))
+
+         return fixed_mask
+
+     def _get_segmentation_mask(self, images, threshold, sigma):
+         input_range = self.diffusion_model.input_range
+         images = ops.clip(images, input_range[0], input_range[1])
+         images = translate(images, input_range, (-1, 1))
+
+         masks = self.segmentation_model(images)
+         mask_vent = masks[..., 0]  # ROI 1: ventricle
+         mask_sept = masks[..., 1]  # ROI 2: septum
+
+         def _preprocess_mask(mask):
+             mask = ops.convert_to_numpy(mask)
+             mask = np.expand_dims(mask, axis=-1)
+             mask = np.where(mask > threshold, 1.0, 0.0)
+             mask = filters.gaussian(mask, sigma=sigma)
+             mask = (mask - ops.min(mask)) / (ops.max(mask) - ops.min(mask) + 1e-8)
+             return mask
+
+         mask_vent = _preprocess_mask(mask_vent)
+         mask_sept = _preprocess_mask(mask_sept)
+         return mask_vent, mask_sept
+
+     def _get_dark_mask(self, images):
+         min_val = self.diffusion_model.input_range[0]
+         dark_mask = ops.where(ops.abs(images - min_val) < 1e-6, 1.0, 0.0)
+         return dark_mask
+
+     def make_omega_map(
+         self, images, mask_params, fixed_mask_params, skeleton_params, guidance_kwargs
+     ):
+         masks = self.get_masks(images, mask_params, fixed_mask_params, skeleton_params)
+
+         masks_vent = masks["vent"]
+         masks_sept = masks["sept"]
+         masks_fixed = masks["fixed"]
+         masks_skeleton = masks["skeleton"]
+         masks_dark = masks["dark"]
+
+         masks_strong = ops.clip(
+             masks_sept + masks_fixed + masks_skeleton + masks_dark, 0, 1
+         )
+
+         # background = not masks_strong, not vent
+         background = ops.where(masks_strong < 0.1, 1.0, 0.0) * ops.where(
+             masks_vent == 0, 1.0, 0.0
+         )
+
+         masks_vent_filtered = masks_vent * (1.0 - masks_strong)
+
+         per_pixel_omega = (
+             guidance_kwargs["omega"] * background
+             + guidance_kwargs["omega_vent"] * masks_vent_filtered
+             + guidance_kwargs["omega_sept"] * masks_strong
+         )
+
+         haze_mask_components = (masks_vent > 0.5) * ((1 - masks_strong) > 0.5)
+
+         haze_mask = []
+         for i, m in enumerate(haze_mask_components):
+             if scipy.ndimage.label(m)[1] > 1:
+                 # masks_strong _splits_ masks_vent into 2 or more components,
+                 # so we fall back to masks_vent
+                 haze_mask.append(masks_vent[i])
+                 # also remove guidance from this region to avoid bringing haze in
+                 per_pixel_omega = per_pixel_omega.at[i].set(
+                     per_pixel_omega[i] * (1 - masks_vent[i])
+                 )
+             else:
+                 # masks_strong 'shaves off' some of masks_vent,
+                 # where there is tissue
+                 haze_mask.append((masks_vent * (1 - masks_strong))[i])
+         haze_mask = ops.stack(haze_mask, axis=0)
+
+         masks["per_pixel_omega"] = per_pixel_omega
+         masks["haze"] = haze_mask
+
+         return masks
+
+     def get_masks(self, images, mask_params, fixed_mask_params, skeleton_params):
+         """Generate all guidance masks from the input images."""
+         masks_vent, masks_sept = self._get_segmentation_mask(images, **mask_params)
+         masks_fixed = self._get_fixed_mask(images, **fixed_mask_params)
+         masks_skeleton = extract_skeleton(
+             images, self.diffusion_model.input_range, **skeleton_params
+         )
+         masks_dark = self._get_dark_mask(images)
+         return {
+             "vent": masks_vent,
+             "sept": masks_sept,
+             "fixed": masks_fixed,
+             "skeleton": masks_skeleton,
+             "dark": masks_dark,
+         }
+
+     def compute_error(
+         self,
+         noisy_images,
+         measurements,
+         noise_rates,
+         signal_rates,
+         per_pixel_omega,
+         haze_mask,
+         eta=0.01,
+         smooth_l1_beta=0.5,
+         **kwargs,
+     ):
+         """Compute the measurement error for diffusion posterior sampling.
+
+         Args:
+             noisy_images: Noisy images.
+             measurements: Target measurements.
+             noise_rates: Current noise rates.
+             signal_rates: Current signal rates.
+             per_pixel_omega: Per-pixel weight for the measurement error.
+             haze_mask: Mask of the haze region used for the sparsity prior.
+             eta: Weight of the haze prior penalty.
+             smooth_l1_beta: Beta parameter of the smooth L1 haze prior.
+             **kwargs: Additional arguments for the operator.
+
+         Returns:
+             Tuple of (total_error, (pred_noises, pred_images))
+         """
+         pred_noises, pred_images = self.diffusion_model.denoise(
+             noisy_images,
+             noise_rates,
+             signal_rates,
+             training=False,
+         )
+
+         measurement_error = L2(
+             per_pixel_omega
+             * (measurements - self.operator.forward(pred_images, **kwargs))
+         )
+
+         hazy_pixels = pred_images * haze_mask
+
+         # Smooth L1 penalty on haze pixels;
+         # add +1 to make -1 (= black) the 'sparse' value
+         haze_prior_error = smooth_L1(hazy_pixels + 1, beta=smooth_l1_beta)
+
+         total_error = measurement_error + eta * haze_prior_error
+
+         return total_error, (pred_noises, pred_images)
+
+
+ def init(config):
+     """Initialize models, operator, and guidance objects for semantic-dps dehazing."""
+
+     operator = IdentityOperator()
+
+     diffusion_model = DiffusionModel.from_preset(
+         config.diffusion_model_path,
+     )
+     log.success(
+         f"Diffusion model loaded from {log.yellow(config.diffusion_model_path)}"
+     )
+     segmentation_model = load_segmentation_model(config.segmentation_model_path)
+
+     log.success(
+         f"Segmentation model loaded from {log.yellow(config.segmentation_model_path)}"
+     )
+
+     guidance_fn = SemanticDPS(
+         diffusion_model=diffusion_model,
+         segmentation_model=segmentation_model,
+         operator=operator,
+     )
+     diffusion_model._init_operator_and_guidance(operator, guidance_fn)
+
+     return diffusion_model
+
+
+ def load_segmentation_model(path):
+     """Load the segmentation model."""
+     segmentation_model = keras.saving.load_model(path)
+     return segmentation_model
+
+
+ def run(
+     hazy_images: np.ndarray,
+     diffusion_model: DiffusionModel,
+     seed,
+     guidance_kwargs: dict,
+     mask_params: dict,
+     fixed_mask_params: dict,
+     skeleton_params: dict,
+     batch_size: int = 4,
+     diffusion_steps: int = 100,
+     initial_diffusion_step: int = 0,
+     threshold_output_quantile: float = None,
+     preserve_bottom_percent: float = 30.0,
+     bottom_transition_width: float = 10.0,
+     verbose: bool = True,
+ ):
+     input_range = diffusion_model.input_range
+
+     hazy_images = preprocess(hazy_images, normalization_range=input_range)
+
+     pred_tissue_images = []
+     masks_out = []
+     num_images = hazy_images.shape[0]
+     num_batches = (num_images + batch_size - 1) // batch_size
+
+     progbar = keras.utils.Progbar(num_batches, verbose=verbose)
+     for i in range(num_batches):
+         batch = hazy_images[i * batch_size : (i * batch_size) + batch_size]
+
+         masks = diffusion_model.guidance_fn.make_omega_map(
+             batch, mask_params, fixed_mask_params, skeleton_params, guidance_kwargs
+         )
+
+         batch_images = diffusion_model.posterior_sample(
+             batch,
+             n_samples=1,
+             n_steps=diffusion_steps,
+             initial_step=initial_diffusion_step,
+             seed=seed,
+             verbose=True,
+             per_pixel_omega=masks["per_pixel_omega"],
+             haze_mask=masks["haze"],
+             eta=guidance_kwargs["eta"],
+             smooth_l1_beta=guidance_kwargs["smooth_l1_beta"],
+         )
+         batch_images = ops.take(batch_images, 0, axis=1)
+
+         pred_tissue_images.append(batch_images)
+         masks_out.append(masks)
+         progbar.update(i + 1)
+
+     pred_tissue_images = ops.concatenate(pred_tissue_images, axis=0)
+     masks_out = {
+         key: ops.concatenate([m[key] for m in masks_out], axis=0)
+         for key in masks_out[0].keys()
+     }
+     pred_haze_images = hazy_images - pred_tissue_images - 1
+
+     if threshold_output_quantile is not None:
+         threshold_value = ops.quantile(
+             pred_tissue_images, threshold_output_quantile, axis=(1, 2), keepdims=True
+         )
+         pred_tissue_images = ops.where(
+             pred_tissue_images < threshold_value, input_range[0], pred_tissue_images
+         )
+
+     # Apply bottom preservation with a smooth transition
+     if preserve_bottom_percent > 0:
+         pred_tissue_images = apply_bottom_preservation(
+             pred_tissue_images,
+             hazy_images,
+             preserve_bottom_percent=preserve_bottom_percent,
+             transition_width=bottom_transition_width,
+         )
+
+     pred_tissue_images = postprocess(pred_tissue_images, input_range)
+     hazy_images = postprocess(hazy_images, input_range)
+     pred_haze_images = postprocess(pred_haze_images, input_range)
+
+     return hazy_images, pred_tissue_images, pred_haze_images, masks_out
+
+
+ def add_shape_from_mask(ax, mask, **kwargs):
+     """Add a shape to an axis from a mask array.
+
+     Args:
+         ax (plt.ax): matplotlib axis
+         mask (ndarray): numpy array whose non-zero
+             region defines the region of interest.
+     Kwargs:
+         edgecolor (str): color of the shape's edge
+         facecolor (str): color of the shape's face
+         linewidth (int): width of the shape's edge
+
+     Returns:
+         plt.ax: matplotlib axis with the shape added
+     """
+     # Pad mask to ensure edge contours are found
+     padded_mask = np.pad(mask, pad_width=1, mode="constant", constant_values=0)
+     contours = measure.find_contours(padded_mask, 0.5)
+     patches = []
+     for contour in contours:
+         # Remove padding offset
+         contour -= 1
+         path = pltPath(contour[:, ::-1])
+         patch = PathPatch(path, **kwargs)
+         patches.append(ax.add_patch(patch))
+     return patches
+
+
+ def plot_batch_with_named_masks(
+     images, masks_dict, mask_colors=None, titles=None, **kwargs
+ ):
+     """
+     Plot a batch of images in rows; each column overlays a different mask from
+     the dict. Mask labels are shown as column titles. If the mask name is
+     'per_pixel_omega', it is shown directly with the inferno colormap (no overlay).
+
+     Args:
+         images: np.ndarray, shape (batch, height, width, channels)
+         masks_dict: dict of {name: mask}, each mask shape (batch, height, width, channels)
+         mask_colors: dict of {name: color} or None (default colors used)
+     """
+     mask_names = list(masks_dict.keys())
+     batch_size = images.shape[0]
+     default_colors = ["red", "green", "#33aaff", "yellow", "magenta", "cyan"]
+     mask_colors = mask_colors or {
+         name: default_colors[i % len(default_colors)]
+         for i, name in enumerate(mask_names)
+     }
+
+     # Prepare images for each column
+     columns = []
+     cmaps = []
+     for name in mask_names:
+         if name == "per_pixel_omega":
+             mask_np = np.array(masks_dict[name])
+             columns.append(np.squeeze(mask_np))
+             cmaps.append(["inferno"] * batch_size)
+         else:
+             columns.append(np.squeeze(images))
+             cmaps.append(["gray"] * batch_size)
+
+     # Stack columns: shape (num_columns, batch, ...)
+     all_images = np.stack(columns, axis=0)
+     # Rearrange to (batch, num_columns, ...)
+     all_images = (
+         np.transpose(all_images, (1, 0, 2, 3, 4))
+         if all_images.ndim == 5
+         else np.transpose(all_images, (1, 0, 2, 3))
+     )
+     # Flatten to (batch * num_columns, ...)
+     all_images = all_images.reshape(batch_size * len(mask_names), *images.shape[1:])
+
+     # Flatten cmaps for plot_image_grid in the same order as the images
+     flat_cmaps = []
+     for row in range(batch_size):
+         for col in range(len(mask_names)):
+             flat_cmaps.append(cmaps[col][row])
+
+     fig, _ = plot_image_grid(
+         all_images,
+         ncols=len(mask_names),
+         remove_axis=False,
+         cmap=flat_cmaps,
+         figsize=(8, 3.3),
+         **kwargs,
+     )
+
+     # Overlay masks for the non-per_pixel_omega columns
+     for col_idx, name in enumerate(mask_names):
+         if name == "per_pixel_omega":
+             continue
+         mask_np = np.array(masks_dict[name])
+         axes = fig.axes[col_idx : batch_size * len(mask_names) : len(mask_names)]
+         for ax, mask_img in zip(axes, mask_np):
+             add_shape_from_mask(
+                 ax, mask_img.squeeze(), color=mask_colors[name], alpha=0.3
+             )
+
+     # Add column titles
+     row_idx = 0
+     if titles is None:
+         titles = mask_names
+     for col_idx, name in enumerate(titles):
+         ax_idx = row_idx * len(mask_names) + col_idx
+         fig.axes[ax_idx].set_title(name, fontsize=9, color="white")
+         fig.axes[ax_idx].set_facecolor("black")
+
+     # Add a colorbar for per_pixel_omega if present
+     if "per_pixel_omega" in mask_names:
+         col_idx = mask_names.index("per_pixel_omega")
+         axes = fig.axes[col_idx : batch_size * len(mask_names) : len(mask_names)]
+
+         # Get vertical bounds of the subplot column
+         top_ax = axes[0]
+         bottom_ax = axes[-1]
+         top_pos = top_ax.get_position()
+         bottom_pos = bottom_ax.get_position()
+
+         full_y0 = bottom_pos.y0
+         full_y1 = top_pos.y1
+         full_height = full_y1 - full_y0
+
+         # Manually shrink to 80% of the full height and center vertically
+         scale = 0.8
+         height = full_height * scale
+         y0 = full_y0 + (full_height - height) / 2
+
+         x0 = top_pos.x1 + 0.015  # Horizontal position to the right
+         width = 0.015  # Thin bar
+
+         # Add the colorbar axis
+         cax = fig.add_axes([x0, y0, width, height])
+
+         im = axes[0].get_images()[0] if axes[0].get_images() else None
+         cbar = fig.colorbar(im, cax=cax)
+         cbar.set_label(r"Guidance weighting $\mathbf{p}$")
+         cbar.ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))
+         cbar.ax.yaxis.set_tick_params(labelsize=7)
+         cbar.ax.yaxis.label.set_size(8)
+
+     return fig
+
+
+ def plot_dehazed_results(
+     hazy_images,
+     pred_tissue_images,
+     pred_haze_images,
+     diffusion_model,
+     titles=("Hazy", "Dehazed", "Haze"),
+ ):
+     """Create the hazy/dehazed/haze comparison figure."""
+
+     # Stack the three image sets into one grid
+     input_shape = diffusion_model.input_shape
+     stack_images = ops.stack(
+         [
+             hazy_images,
+             pred_tissue_images,
+             pred_haze_images,
+         ]
+     )
+     stack_images = ops.reshape(stack_images, (-1, input_shape[0], input_shape[1]))
+
+     fig, _ = plot_image_grid(
+         stack_images,
+         ncols=len(hazy_images),
+         remove_axis=False,
+         vmin=0,
+         vmax=255,
+     )
+     # Label each row with its title
+     for i, ax in enumerate(fig.axes):
+         if i % len(hazy_images) == 0:
+             label = titles[(i // len(hazy_images)) % len(titles)]
+             ax.set_ylabel(label, fontsize=12)
+
+     return fig
+
+
+ def main(
+     input_folder: str = "./assets",
+     output_folder: str = "./temp",
+     num_imgs_plot: int = 4,
+     device: str = "auto:1",
+     config: str = "configs/semantic_dps.yaml",
+ ):
+     num_img = num_imgs_plot
+
+     zea.visualize.set_mpl_style()
+     init_device(device)
+
+     config = Config.from_yaml(config)
+     seed = jax.random.PRNGKey(config.seed)
+
+     paths = list(Path(input_folder).glob("*.png"))
+
+     output_folder = Path(output_folder)
+
+     images = []
+     for path in paths:
+         image = zea.io_lib.load_image(path)
+         images.append(image)
+     images = ops.stack(images, axis=0)
+
+     diffusion_model = init(config)
+
+     hazy_images, pred_tissue_images, pred_haze_images, masks = run(
+         images,
+         diffusion_model=diffusion_model,
+         seed=seed,
+         **config.params,
+     )
+
+     output_folder.mkdir(parents=True, exist_ok=True)
+
+     for image, path in zip(pred_tissue_images, paths):
+         image = ops.convert_to_numpy(image)
+         file_name = path.name
+         Image.fromarray(image).save(output_folder / file_name)
+
+     fig = plot_dehazed_results(
+         hazy_images[:num_img],
+         pred_tissue_images[:num_img],
+         pred_haze_images[:num_img],
+         diffusion_model,
+         titles=[
+             r"Hazy $\mathbf{y}$",
+             r"Dehazed $\mathbf{\hat{x}}$",
+             r"Haze $\mathbf{\hat{h}}$",
+         ],
+     )
+     path = Path("dehazed_results.png")
+     save_kwargs = {"bbox_inches": "tight", "dpi": 300}
+     fig.savefig(path, **save_kwargs)
+     fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
+     log.success(f"Dehazed results saved to {log.yellow(path)}")
+
+     masks_viz = copy.deepcopy(masks)
+     masks_viz.pop("haze")
+
+     masks_viz = {k: v[:num_img] for k, v in masks_viz.items()}
+
+     fig = plot_batch_with_named_masks(
+         images[:num_img],
+         masks_viz,
+         titles=[
+             r"Ventricle $v(\mathbf{y})$",
+             r"Septum $s(\mathbf{y})$",
+             r"Fixed",
+             r"Skeleton $t(\mathbf{y})$",
+             r"Dark $b(\mathbf{y})$",
+             r"Guidance $d(\mathbf{y})$",
+         ],
+     )
+     path = Path("segmentation_steps.png")
+     fig.savefig(path, **save_kwargs)
+     fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
+     log.success(f"Segmentation steps saved to {log.yellow(path)}")
+
+     plt.close("all")
+
+
+ if __name__ == "__main__":
+     tyro.cli(main)
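
Since main() is wrapped with tyro.cli, its keyword arguments double as the command-line interface (with tyro's usual kebab-cased flags, e.g. --input-folder). A minimal sketch of driving the pipeline from Python with the defaults above:

from main import main

# Dehaze every .png in ./assets and write the results to ./temp,
# using the tuned configuration from configs/semantic_dps.yaml.
main(
    input_folder="./assets",
    output_folder="./temp",
    num_imgs_plot=4,
    config="configs/semantic_dps.yaml",
)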