tristan-deep commited on
Commit
eadd412
·
1 Parent(s): 0c7f2c0

fix eval and plotting refactor

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. README.md +1 -1
  3. eval.py +38 -60
  4. fid_score.py +3 -2
  5. main.py +3 -187
  6. plots.py +254 -0
.gitignore CHANGED
@@ -3,3 +3,5 @@
3
  temp/
4
  *.png
5
  *.pdf
 
 
 
3
  temp/
4
  *.png
5
  *.pdf
6
+ *.hash
7
+ *.npz
README.md CHANGED
@@ -7,7 +7,7 @@
7
  <a href="https://keras.io/"><img src="https://img.shields.io/badge/Keras-EE4C2C?logo=keras&logoColor=white" alt="Keras"></a>
8
  </p>
9
  <h3>
10
- <span style="display:inline-block; margin: 0 20px;">
11
  <a href="https://example.com/tristan-stevens">Tristan Stevens</a>
12
  </span>
13
  <span style="display:inline-block; margin: 0 20px;">
 
7
  <a href="https://keras.io/"><img src="https://img.shields.io/badge/Keras-EE4C2C?logo=keras&logoColor=white" alt="Keras"></a>
8
  </p>
9
  <h3>
10
+ <span style="display:inline-block; margin: 0 40px;">
11
  <a href="https://example.com/tristan-stevens">Tristan Stevens</a>
12
  </span>
13
  <span style="display:inline-block; margin: 0 20px;">
eval.py CHANGED
@@ -2,16 +2,17 @@ import warnings
2
  from glob import glob
3
  from pathlib import Path
4
 
5
- import matplotlib.pyplot as plt
6
  import numpy as np
7
  import torch
8
  import tyro
9
  from PIL import Image
10
  from scipy.ndimage import binary_erosion, distance_transform_edt
11
  from scipy.stats import ks_2samp
 
12
  from zea.io_lib import load_image
13
 
14
  import fid_score
 
15
 
16
 
17
  def calculate_fid_score(denoised_image_dirs, ground_truth_dir):
@@ -207,64 +208,38 @@ def calculate_final_score(aggregates):
207
  return 0
208
 
209
 
210
- def plot_metrics(metrics, limits, out_path):
211
- plt.style.use("seaborn-v0_8-darkgrid")
212
- fig, axes = plt.subplots(1, len(metrics), figsize=(7.2, 2.7), dpi=600)
213
- colors = ["#0057b7", "#ffb300", "#008744", "#d62d20"]
214
- # Arrow direction: for up, for down
215
- metric_labels = {
216
- "CNR": r"CNR $\uparrow$",
217
- "gCNR": r"gCNR $\uparrow$",
218
- "KS_A": r"KS$_{septum}$ $\downarrow$",
219
- "KS_B": r"KS$_{ventricle}$ $\uparrow$",
220
- }
221
- for idx, (ax, (name, values)) in enumerate(zip(axes, metrics.items())):
222
- ax.hist(
223
- values,
224
- bins=30,
225
- color=colors[idx % len(colors)],
226
- alpha=0.85,
227
- edgecolor="black",
228
- linewidth=0.7,
229
- )
230
- ax.set_xlabel(metric_labels.get(name, name), fontsize=11)
231
- ax.set_ylabel("Count", fontsize=10)
232
- # Draw limits
233
- if name in limits:
234
- for lim in limits[name]:
235
- ax.axvline(lim, color="crimson", linestyle="--", lw=1.2)
236
- ax.spines["top"].set_visible(False)
237
- ax.spines["right"].set_visible(False)
238
- ax.tick_params(axis="both", which="major", labelsize=9)
239
- fig.tight_layout(pad=1.5)
240
- fig.savefig(out_path, bbox_inches="tight", dpi=600)
241
- plt.close(fig)
242
-
243
-
244
- def main(folder: str, roi_folder: str, reference_folder: str):
245
  folder = Path(folder)
 
246
  roi_folder = Path(roi_folder)
247
  reference_folder = Path(reference_folder)
248
 
249
  folder_files = set(f.name for f in folder.glob("*.png"))
 
250
  roi_files = set(f.name for f in roi_folder.glob("*.png"))
251
- ref_files = set(f.name for f in reference_folder.glob("*.png"))
252
 
253
  print(f"Found {len(folder_files)} .png files in output folder: {folder}")
 
254
  print(f"Found {len(roi_files)} .png files in ROI folder: {roi_folder}")
255
- print(f"Found {len(ref_files)} .png files in reference folder: {reference_folder}")
256
 
257
  # Find intersection of filenames
258
- common_files = sorted(folder_files & roi_files & ref_files)
259
- print(f"Found {len(common_files)} images present in all folders.")
260
- if len(common_files) == 0:
261
- print("No matching images found in all folders. Check your folder contents.")
262
- print(f"Output folder files: {sorted(folder_files)}")
263
- print(f"ROI folder files: {sorted(roi_files)}")
264
- print(f"Reference folder files: {sorted(ref_files)}")
265
- assert len(common_files) > 0, (
266
- "No matching .png files in all folders. Cannot proceed."
267
- )
268
 
269
  metrics = {"CNR": [], "gCNR": [], "KS_A": [], "KS_B": []}
270
  limits = {
@@ -275,28 +250,26 @@ def main(folder: str, roi_folder: str, reference_folder: str):
275
  }
276
 
277
  for name in common_files:
278
- our_path = folder / name
 
279
  roi_path = roi_folder / name
280
- ref_path = reference_folder / name
281
-
282
- assert our_path.exists(), f"Missing file in output folder: {our_path}"
283
- assert roi_path.exists(), f"Missing file in ROI folder: {roi_path}"
284
- assert ref_path.exists(), f"Missing file in reference folder: {ref_path}"
285
 
286
  try:
287
- img = np.array(load_image(str(our_path)))
288
- img_ref = np.array(load_image(str(ref_path)))
289
  except Exception as e:
290
  print(f"Error loading image {name}: {e}")
291
  continue
292
 
293
  # CNR/gCNR
294
- cnr_gcnr = calculate_cnr_gcnr(img, str(roi_path))
295
  metrics["CNR"].append(cnr_gcnr[0][0])
296
  metrics["gCNR"].append(cnr_gcnr[0][1])
297
 
298
  # KS statistics
299
- ks_a, _, ks_b, _ = calculate_ks_statistics(img_ref, img, str(roi_path))
 
 
300
  metrics["KS_A"].append(ks_a)
301
  metrics["KS_B"].append(ks_b)
302
 
@@ -308,8 +281,13 @@ def main(folder: str, roi_folder: str, reference_folder: str):
308
  for k, (mean, std, minv, maxv) in stats.items():
309
  print(f"{k}: mean={mean:.3f}, std={std:.3f}, min={minv:.3f}, max={maxv:.3f}")
310
 
311
- plot_metrics(metrics, limits, str(folder / "contrast_metrics.png"))
312
- print(f"Saved metrics plot to {folder / 'contrast_metrics.png'}")
 
 
 
 
 
313
 
314
  # Compute FID
315
  fid_image_paths = [str(folder / name) for name in common_files]
 
2
  from glob import glob
3
  from pathlib import Path
4
 
 
5
  import numpy as np
6
  import torch
7
  import tyro
8
  from PIL import Image
9
  from scipy.ndimage import binary_erosion, distance_transform_edt
10
  from scipy.stats import ks_2samp
11
+ from zea import log
12
  from zea.io_lib import load_image
13
 
14
  import fid_score
15
+ from plots import plot_metrics
16
 
17
 
18
  def calculate_fid_score(denoised_image_dirs, ground_truth_dir):
 
208
  return 0
209
 
210
 
211
+ def main(folder: str, noisy_folder: str, roi_folder: str, reference_folder: str):
212
+ """Evaluate the dehazing algorithm.
213
+
214
+ Args:
215
+ folder (str): Path to the folder containing the dehazed images.
216
+ Used for evaluating all metrics.
217
+ noisy_folder (str): Path to the folder containing the noisy images.
218
+ Only used for KS statistics.
219
+ roi_folder (str): Path to the folder containing the ROI images.
220
+ Used for contrast and KS statistic metrics.
221
+ reference_folder (str): Path to the folder containing the reference images.
222
+ Used only for FID calculation.
223
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  folder = Path(folder)
225
+ noisy_folder = Path(noisy_folder)
226
  roi_folder = Path(roi_folder)
227
  reference_folder = Path(reference_folder)
228
 
229
  folder_files = set(f.name for f in folder.glob("*.png"))
230
+ noisy_files = set(f.name for f in noisy_folder.glob("*.png"))
231
  roi_files = set(f.name for f in roi_folder.glob("*.png"))
 
232
 
233
  print(f"Found {len(folder_files)} .png files in output folder: {folder}")
234
+ print(f"Found {len(noisy_files)} .png files in noisy folder: {noisy_folder}")
235
  print(f"Found {len(roi_files)} .png files in ROI folder: {roi_folder}")
 
236
 
237
  # Find intersection of filenames
238
+ common_files = sorted(folder_files & roi_files & noisy_files)
239
+ print(f"Found {len(common_files)} matching images in noisy/dehazed/roi folders")
240
+ assert len(common_files) > 0, (
241
+ "No matching .png files in all folders. Cannot proceed."
242
+ )
 
 
 
 
 
243
 
244
  metrics = {"CNR": [], "gCNR": [], "KS_A": [], "KS_B": []}
245
  limits = {
 
250
  }
251
 
252
  for name in common_files:
253
+ dehazed_path = folder / name
254
+ noisy_path = noisy_folder / name
255
  roi_path = roi_folder / name
 
 
 
 
 
256
 
257
  try:
258
+ img_dehazed = np.array(load_image(str(dehazed_path)))
259
+ img_noisy = np.array(load_image(str(noisy_path)))
260
  except Exception as e:
261
  print(f"Error loading image {name}: {e}")
262
  continue
263
 
264
  # CNR/gCNR
265
+ cnr_gcnr = calculate_cnr_gcnr(img_dehazed, str(roi_path))
266
  metrics["CNR"].append(cnr_gcnr[0][0])
267
  metrics["gCNR"].append(cnr_gcnr[0][1])
268
 
269
  # KS statistics
270
+ ks_a, _, ks_b, _ = calculate_ks_statistics(
271
+ img_noisy, img_dehazed, str(roi_path)
272
+ )
273
  metrics["KS_A"].append(ks_a)
274
  metrics["KS_B"].append(ks_b)
275
 
 
281
  for k, (mean, std, minv, maxv) in stats.items():
282
  print(f"{k}: mean={mean:.3f}, std={std:.3f}, min={minv:.3f}, max={maxv:.3f}")
283
 
284
+ fig = plot_metrics(metrics, limits, "contrast_metrics.png")
285
+
286
+ path = Path("contrast_metrics.png")
287
+ save_kwargs = {"bbox_inches": "tight", "dpi": 300}
288
+ fig.savefig(path, **save_kwargs)
289
+ fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
290
+ log.success(f"Metrics plot saved to {log.yellow(path)}")
291
 
292
  # Compute FID
293
  fid_image_paths = [str(folder / name) for name in common_files]
fid_score.py CHANGED
@@ -88,6 +88,7 @@ parser.add_argument(
88
  )
89
 
90
  IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
 
91
 
92
 
93
  class ImagePathDataset(torch.utils.data.Dataset):
@@ -277,7 +278,7 @@ def compute_statistics_of_path(path, model, batch_size, dims, device, num_worker
277
 
278
 
279
  def _fid_cache_paths():
280
- tmp_dir = pathlib.Path("tmp")
281
  tmp_dir.mkdir(exist_ok=True)
282
  stats_path = tmp_dir / "fid_stats.npz"
283
  hash_path = tmp_dir / "fid_stats.hash"
@@ -391,7 +392,7 @@ def calculate_fid_with_cached_ground_truth(
391
  continue
392
  return hash_md5.hexdigest()
393
 
394
- tmp_dir = pathlib.Path("tmp")
395
  tmp_dir.mkdir(exist_ok=True)
396
  stats_path = tmp_dir / "fid_stats.npz"
397
  hash_path = tmp_dir / "fid_stats.hash"
 
88
  )
89
 
90
  IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
91
+ TEMP_DIR = pathlib.Path("temp")
92
 
93
 
94
  class ImagePathDataset(torch.utils.data.Dataset):
 
278
 
279
 
280
  def _fid_cache_paths():
281
+ tmp_dir = TEMP_DIR
282
  tmp_dir.mkdir(exist_ok=True)
283
  stats_path = tmp_dir / "fid_stats.npz"
284
  hash_path = tmp_dir / "fid_stats.hash"
 
392
  continue
393
  return hash_md5.hexdigest()
394
 
395
+ tmp_dir = TEMP_DIR
396
  tmp_dir.mkdir(exist_ok=True)
397
  stats_path = tmp_dir / "fid_stats.npz"
398
  hash_path = tmp_dir / "fid_stats.hash"
main.py CHANGED
@@ -1,9 +1,6 @@
1
  import copy
2
- import os
3
  from pathlib import Path
4
 
5
- os.environ["KERAS_BACKEND"] = "jax"
6
-
7
  import jax
8
  import keras
9
  import matplotlib.pyplot as plt
@@ -12,10 +9,8 @@ import scipy
12
  import tyro
13
  import zea
14
  from keras import ops
15
- from matplotlib.patches import PathPatch
16
- from matplotlib.path import Path as pltPath
17
  from PIL import Image
18
- from skimage import filters, measure, morphology
19
  from zea import Config, init_device, log
20
  from zea.internal.operators import Operator
21
  from zea.models.diffusion import (
@@ -25,7 +20,8 @@ from zea.models.diffusion import (
25
  )
26
  from zea.tensor_ops import L2
27
  from zea.utils import translate
28
- from zea.visualize import plot_image_grid
 
29
 
30
 
31
  def L1(x):
@@ -476,186 +472,6 @@ def run(
476
  return hazy_images, pred_tissue_images, pred_haze_images, masks_out
477
 
478
 
479
- def add_shape_from_mask(ax, mask, **kwargs):
480
- """add a shape to axis from mask array.
481
-
482
- Args:
483
- ax (plt.ax): matplotlib axis
484
- mask (ndarray): numpy array with non-zero
485
- shape defining the region of interest.
486
- Kwargs:
487
- edgecolor (str): color of the shape's edge
488
- facecolor (str): color of the shape's face
489
- linewidth (int): width of the shape's edge
490
-
491
- Returns:
492
- plt.ax: matplotlib axis with shape added
493
- """
494
- # Pad mask to ensure edge contours are found
495
- padded_mask = np.pad(mask, pad_width=1, mode="constant", constant_values=0)
496
- contours = measure.find_contours(padded_mask, 0.5)
497
- patches = []
498
- for contour in contours:
499
- # Remove padding offset
500
- contour -= 1
501
- path = pltPath(contour[:, ::-1])
502
- patch = PathPatch(path, **kwargs)
503
- patches.append(ax.add_patch(patch))
504
- return patches
505
-
506
-
507
- def plot_batch_with_named_masks(
508
- images, masks_dict, mask_colors=None, titles=None, **kwargs
509
- ):
510
- """
511
- Plot batch of images in rows, each column overlays a different mask from the dict.
512
- Mask labels are shown as column titles. If mask name is 'per_pixel_omega', show it
513
- directly with inferno colormap (no overlay).
514
-
515
- Args:
516
- images: np.ndarray, shape (batch, height, width, channels)
517
- masks_dict: dict of {name: mask}, each mask shape (batch, height, width, channels)
518
- mask_colors: dict of {name: color} or None (default colors used)
519
- """
520
- mask_names = list(masks_dict.keys())
521
- batch_size = images.shape[0]
522
- default_colors = ["red", "green", "#33aaff", "yellow", "magenta", "cyan"]
523
- mask_colors = mask_colors or {
524
- name: default_colors[i % len(default_colors)]
525
- for i, name in enumerate(mask_names)
526
- }
527
-
528
- # Prepare images for each column
529
- columns = []
530
- cmaps = []
531
- for name in mask_names:
532
- if name == "per_pixel_omega":
533
- mask_np = np.array(masks_dict[name])
534
- columns.append(np.squeeze(mask_np))
535
- cmaps.append(["inferno"] * batch_size)
536
- else:
537
- columns.append(np.squeeze(images))
538
- cmaps.append(["gray"] * batch_size)
539
-
540
- # Stack columns: shape (num_columns, batch, ...)
541
- all_images = np.stack(columns, axis=0) # (num_columns, batch, ...)
542
- # Rearrange to (batch, num_columns, ...)
543
- all_images = (
544
- np.transpose(all_images, (1, 0, 2, 3, 4))
545
- if all_images.ndim == 5
546
- else np.transpose(all_images, (1, 0, 2, 3))
547
- )
548
- # Flatten to (batch * num_columns, ...)
549
- all_images = all_images.reshape(batch_size * len(mask_names), *images.shape[1:])
550
-
551
- # Flatten cmaps for plot_image_grid in the same order as images
552
- flat_cmaps = []
553
- for row in range(batch_size):
554
- for col in range(len(mask_names)):
555
- flat_cmaps.append(cmaps[col][row])
556
-
557
- fig, _ = plot_image_grid(
558
- all_images,
559
- ncols=len(mask_names),
560
- remove_axis=False,
561
- cmap=flat_cmaps,
562
- figsize=(8, 3.3),
563
- **kwargs,
564
- )
565
-
566
- # Overlay masks for non-per_pixel_omega columns
567
- for col_idx, name in enumerate(mask_names):
568
- if name == "per_pixel_omega":
569
- continue
570
- mask_np = np.array(masks_dict[name])
571
- axes = fig.axes[col_idx : batch_size * len(mask_names) : len(mask_names)]
572
- for ax, mask_img in zip(axes, mask_np):
573
- add_shape_from_mask(
574
- ax, mask_img.squeeze(), color=mask_colors[name], alpha=0.3
575
- )
576
-
577
- # Add column titles
578
- row_idx = 0
579
- if titles is None:
580
- titles = mask_names
581
- for col_idx, name in enumerate(titles):
582
- ax_idx = row_idx * len(mask_names) + col_idx
583
- fig.axes[ax_idx].set_title(name, fontsize=9, color="white")
584
- fig.axes[ax_idx].set_facecolor("black")
585
-
586
- # Add colorbar for per_pixel_omega if present
587
- if "per_pixel_omega" in mask_names:
588
- col_idx = mask_names.index("per_pixel_omega")
589
- axes = fig.axes[col_idx : batch_size * len(mask_names) : len(mask_names)]
590
-
591
- # Get vertical bounds of the subplot column
592
- top_ax = axes[0]
593
- bottom_ax = axes[-1]
594
- top_pos = top_ax.get_position()
595
- bottom_pos = bottom_ax.get_position()
596
-
597
- full_y0 = bottom_pos.y0
598
- full_y1 = top_pos.y1
599
- full_height = full_y1 - full_y0
600
-
601
- # Manually shrink to 80% of full height and center vertically
602
- scale = 0.8
603
- height = full_height * scale
604
- y0 = full_y0 + (full_height - height) / 2
605
-
606
- x0 = top_pos.x1 + 0.015 # Horizontal position to the right
607
- width = 0.015 # Thin bar
608
-
609
- # Add colorbar axis
610
- cax = fig.add_axes([x0, y0, width, height])
611
-
612
- im = axes[0].get_images()[0] if axes[0].get_images() else None
613
- cbar = fig.colorbar(im, cax=cax)
614
- cbar.set_label(r"Guidance weighting \mathbf{p}")
615
- cbar.ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))
616
- cbar.ax.yaxis.set_tick_params(labelsize=7)
617
- cbar.ax.yaxis.label.set_size(8)
618
-
619
- return fig
620
-
621
-
622
- def plot_dehazed_results(
623
- hazy_images,
624
- pred_tissue_images,
625
- pred_haze_images,
626
- diffusion_model,
627
- titles=("Hazy", "Dehazed", "Haze"),
628
- ):
629
- """Create and save visualization with optional mask overlays."""
630
-
631
- # Create the processed image stack using the helper function
632
- input_shape = diffusion_model.input_shape
633
- stack_images = ops.stack(
634
- [
635
- hazy_images,
636
- pred_tissue_images,
637
- pred_haze_images,
638
- ]
639
- )
640
- stack_images = ops.reshape(stack_images, (-1, input_shape[0], input_shape[1]))
641
-
642
- # Define labels based on what we're showing
643
- fig, _ = plot_image_grid(
644
- stack_images,
645
- ncols=len(hazy_images),
646
- remove_axis=False,
647
- vmin=0,
648
- vmax=255,
649
- )
650
- # Set labels and styling
651
- for i, ax in enumerate(fig.axes):
652
- if i % len(hazy_images) == 0:
653
- label = titles[(i // len(hazy_images)) % len(titles)]
654
- ax.set_ylabel(label, fontsize=12)
655
-
656
- return fig
657
-
658
-
659
  def main(
660
  input_folder: str = "./assets",
661
  output_folder: str = "./temp",
 
1
  import copy
 
2
  from pathlib import Path
3
 
 
 
4
  import jax
5
  import keras
6
  import matplotlib.pyplot as plt
 
9
  import tyro
10
  import zea
11
  from keras import ops
 
 
12
  from PIL import Image
13
+ from skimage import filters, morphology
14
  from zea import Config, init_device, log
15
  from zea.internal.operators import Operator
16
  from zea.models.diffusion import (
 
20
  )
21
  from zea.tensor_ops import L2
22
  from zea.utils import translate
23
+
24
+ from plots import plot_batch_with_named_masks, plot_dehazed_results
25
 
26
 
27
  def L1(x):
 
472
  return hazy_images, pred_tissue_images, pred_haze_images, masks_out
473
 
474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  def main(
476
  input_folder: str = "./assets",
477
  output_folder: str = "./temp",
plots.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from keras import ops
4
+ from matplotlib.patches import PathPatch
5
+ from matplotlib.path import Path as pltPath
6
+ from skimage import measure
7
+ from zea.visualize import plot_image_grid
8
+
9
+
10
+ def add_shape_from_mask(ax, mask, **kwargs):
11
+ """add a shape to axis from mask array.
12
+
13
+ Args:
14
+ ax (plt.ax): matplotlib axis
15
+ mask (ndarray): numpy array with non-zero
16
+ shape defining the region of interest.
17
+ Kwargs:
18
+ edgecolor (str): color of the shape's edge
19
+ facecolor (str): color of the shape's face
20
+ linewidth (int): width of the shape's edge
21
+
22
+ Returns:
23
+ plt.ax: matplotlib axis with shape added
24
+ """
25
+ # Pad mask to ensure edge contours are found
26
+ padded_mask = np.pad(mask, pad_width=1, mode="constant", constant_values=0)
27
+ contours = measure.find_contours(padded_mask, 0.5)
28
+ patches = []
29
+ for contour in contours:
30
+ # Remove padding offset
31
+ contour -= 1
32
+ path = pltPath(contour[:, ::-1])
33
+ patch = PathPatch(path, **kwargs)
34
+ patches.append(ax.add_patch(patch))
35
+ return patches
36
+
37
+
38
+ def plot_batch_with_named_masks(
39
+ images, masks_dict, mask_colors=None, titles=None, **kwargs
40
+ ):
41
+ """
42
+ Plot batch of images in rows, each column overlays a different mask from the dict.
43
+ Mask labels are shown as column titles. If mask name is 'per_pixel_omega', show it
44
+ directly with inferno colormap (no overlay).
45
+
46
+ Args:
47
+ images: np.ndarray, shape (batch, height, width, channels)
48
+ masks_dict: dict of {name: mask}, each mask shape (batch, height, width, channels)
49
+ mask_colors: dict of {name: color} or None (default colors used)
50
+ """
51
+ mask_names = list(masks_dict.keys())
52
+ batch_size = images.shape[0]
53
+ default_colors = ["red", "green", "#33aaff", "yellow", "magenta", "cyan"]
54
+ mask_colors = mask_colors or {
55
+ name: default_colors[i % len(default_colors)]
56
+ for i, name in enumerate(mask_names)
57
+ }
58
+
59
+ # Prepare images for each column
60
+ columns = []
61
+ cmaps = []
62
+ for name in mask_names:
63
+ if name == "per_pixel_omega":
64
+ mask_np = np.array(masks_dict[name])
65
+ columns.append(np.squeeze(mask_np))
66
+ cmaps.append(["inferno"] * batch_size)
67
+ else:
68
+ columns.append(np.squeeze(images))
69
+ cmaps.append(["gray"] * batch_size)
70
+
71
+ # Stack columns: shape (num_columns, batch, ...)
72
+ all_images = np.stack(columns, axis=0) # (num_columns, batch, ...)
73
+ # Rearrange to (batch, num_columns, ...)
74
+ all_images = (
75
+ np.transpose(all_images, (1, 0, 2, 3, 4))
76
+ if all_images.ndim == 5
77
+ else np.transpose(all_images, (1, 0, 2, 3))
78
+ )
79
+ # Flatten to (batch * num_columns, ...)
80
+ all_images = all_images.reshape(batch_size * len(mask_names), *images.shape[1:])
81
+
82
+ # Flatten cmaps for plot_image_grid in the same order as images
83
+ flat_cmaps = []
84
+ for row in range(batch_size):
85
+ for col in range(len(mask_names)):
86
+ flat_cmaps.append(cmaps[col][row])
87
+
88
+ fig, _ = plot_image_grid(
89
+ all_images,
90
+ ncols=len(mask_names),
91
+ remove_axis=False,
92
+ cmap=flat_cmaps,
93
+ figsize=(8, 3.3),
94
+ **kwargs,
95
+ )
96
+
97
+ # Overlay masks for non-per_pixel_omega columns
98
+ for col_idx, name in enumerate(mask_names):
99
+ if name == "per_pixel_omega":
100
+ continue
101
+ mask_np = np.array(masks_dict[name])
102
+ axes = fig.axes[col_idx : batch_size * len(mask_names) : len(mask_names)]
103
+ for ax, mask_img in zip(axes, mask_np):
104
+ add_shape_from_mask(
105
+ ax, mask_img.squeeze(), color=mask_colors[name], alpha=0.3
106
+ )
107
+
108
+ # Add column titles
109
+ row_idx = 0
110
+ if titles is None:
111
+ titles = mask_names
112
+ for col_idx, name in enumerate(titles):
113
+ ax_idx = row_idx * len(mask_names) + col_idx
114
+ fig.axes[ax_idx].set_title(name, fontsize=9, color="white")
115
+ fig.axes[ax_idx].set_facecolor("black")
116
+
117
+ # Add colorbar for per_pixel_omega if present
118
+ if "per_pixel_omega" in mask_names:
119
+ col_idx = mask_names.index("per_pixel_omega")
120
+ axes = fig.axes[col_idx : batch_size * len(mask_names) : len(mask_names)]
121
+
122
+ # Get vertical bounds of the subplot column
123
+ top_ax = axes[0]
124
+ bottom_ax = axes[-1]
125
+ top_pos = top_ax.get_position()
126
+ bottom_pos = bottom_ax.get_position()
127
+
128
+ full_y0 = bottom_pos.y0
129
+ full_y1 = top_pos.y1
130
+ full_height = full_y1 - full_y0
131
+
132
+ # Manually shrink to 80% of full height and center vertically
133
+ scale = 0.8
134
+ height = full_height * scale
135
+ y0 = full_y0 + (full_height - height) / 2
136
+
137
+ x0 = top_pos.x1 + 0.015 # Horizontal position to the right
138
+ width = 0.015 # Thin bar
139
+
140
+ # Add colorbar axis
141
+ cax = fig.add_axes([x0, y0, width, height])
142
+
143
+ im = axes[0].get_images()[0] if axes[0].get_images() else None
144
+ cbar = fig.colorbar(im, cax=cax)
145
+ cbar.set_label(r"Guidance weighting \mathbf{p}")
146
+ cbar.ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6))
147
+ cbar.ax.yaxis.set_tick_params(labelsize=7)
148
+ cbar.ax.yaxis.label.set_size(8)
149
+
150
+ return fig
151
+
152
+
153
+ def plot_dehazed_results(
154
+ hazy_images,
155
+ pred_tissue_images,
156
+ pred_haze_images,
157
+ diffusion_model,
158
+ titles=("Hazy", "Dehazed", "Haze"),
159
+ ):
160
+ """Create and save visualization with optional mask overlays."""
161
+
162
+ # Create the processed image stack using the helper function
163
+ input_shape = diffusion_model.input_shape
164
+ stack_images = ops.stack(
165
+ [
166
+ hazy_images,
167
+ pred_tissue_images,
168
+ pred_haze_images,
169
+ ]
170
+ )
171
+ stack_images = ops.reshape(stack_images, (-1, input_shape[0], input_shape[1]))
172
+
173
+ # Define labels based on what we're showing
174
+ fig, _ = plot_image_grid(
175
+ stack_images,
176
+ ncols=len(hazy_images),
177
+ remove_axis=False,
178
+ vmin=0,
179
+ vmax=255,
180
+ )
181
+ # Set labels and styling
182
+ for i, ax in enumerate(fig.axes):
183
+ if i % len(hazy_images) == 0:
184
+ label = titles[(i // len(hazy_images)) % len(titles)]
185
+ ax.set_ylabel(label, fontsize=12)
186
+
187
+ return fig
188
+
189
+
190
+ def plot_metrics(metrics, limits, out_path):
191
+ plt.style.use("seaborn-v0_8-darkgrid")
192
+ fig, axes = plt.subplots(1, len(metrics), figsize=(7.2, 2.7), dpi=600)
193
+ colors = ["#0057b7", "#ffb300", "#008744", "#d62d20"]
194
+ metric_labels = {
195
+ "CNR": r"CNR $\uparrow$",
196
+ "gCNR": r"gCNR $\uparrow$",
197
+ "KS_A": r"KS$_{septum}$ $\downarrow$",
198
+ "KS_B": r"KS$_{ventricle}$ $\uparrow$",
199
+ }
200
+ # For legend handles
201
+ legend_handles = []
202
+ import matplotlib.lines as mlines
203
+
204
+ min_style = {
205
+ "color": "crimson",
206
+ "linestyle": "--",
207
+ "lw": 1.2,
208
+ "marker": "o",
209
+ "markersize": 5,
210
+ }
211
+ max_style = {
212
+ "color": "crimson",
213
+ "linestyle": ":",
214
+ "lw": 1.2,
215
+ "marker": "s",
216
+ "markersize": 5,
217
+ }
218
+ for idx, (ax, (name, values)) in enumerate(zip(axes, metrics.items())):
219
+ ax.hist(
220
+ values,
221
+ bins=30,
222
+ color=colors[idx % len(colors)],
223
+ alpha=0.85,
224
+ edgecolor="black",
225
+ linewidth=0.7,
226
+ )
227
+ ax.set_xlabel(metric_labels.get(name, name), fontsize=11)
228
+ if idx == 0:
229
+ ax.set_ylabel("Count", fontsize=10)
230
+ # Draw limits and collect legend handles only once
231
+ if name in limits:
232
+ lims = limits[name]
233
+ if len(legend_handles) == 0:
234
+ # Only add legend handles for the first metric
235
+ min_handle = mlines.Line2D([], [], **min_style, label="min score")
236
+ max_handle = mlines.Line2D([], [], **max_style, label="max score")
237
+ legend_handles.extend([min_handle, max_handle])
238
+ if len(lims) > 0:
239
+ ax.axvline(lims[0], **min_style)
240
+ if len(lims) > 1:
241
+ ax.axvline(lims[1], **max_style)
242
+ ax.spines["top"].set_visible(False)
243
+ ax.spines["right"].set_visible(False)
244
+ ax.tick_params(axis="both", which="major", labelsize=9)
245
+ # Place legend above all subplots
246
+ fig.legend(
247
+ handles=legend_handles,
248
+ loc="upper center",
249
+ ncol=2,
250
+ fontsize=10,
251
+ frameon=False,
252
+ bbox_to_anchor=(0.5, 1.02),
253
+ )
254
+ return fig