|
import json |
|
from io import BytesIO |
|
from pathlib import Path |
|
from typing import Any, Dict, List |
|
|
|
import keras |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import tyro |
|
from keras import ops |
|
from matplotlib.patches import PathPatch |
|
from matplotlib.path import Path as pltPath |
|
from PIL import Image |
|
from skimage import measure |
|
from zea import log |
|
from zea.utils import save_to_gif |
|
from zea.visualize import plot_image_grid |
|
|
|
from utils import postprocess |
|
|
|
|
|
def add_shape_from_mask(ax, mask, **kwargs): |
|
"""add a shape to axis from mask array. |
|
|
|
Args: |
|
ax (plt.ax): matplotlib axis |
|
mask (ndarray): numpy array with non-zero |
|
shape defining the region of interest. |
|
Kwargs: |
|
edgecolor (str): color of the shape's edge |
|
facecolor (str): color of the shape's face |
|
linewidth (int): width of the shape's edge |
|
|
|
Returns: |
|
plt.ax: matplotlib axis with shape added |
|
""" |
|
|
|
padded_mask = np.pad(mask, pad_width=1, mode="constant", constant_values=0) |
|
contours = measure.find_contours(padded_mask, 0.5) |
|
patches = [] |
|
for contour in contours: |
|
|
|
contour -= 1 |
|
path = pltPath(contour[:, ::-1]) |
|
patch = PathPatch(path, **kwargs) |
|
patches.append(ax.add_patch(patch)) |
|
return patches |
|
|
|
|
|
def matplotlib_figure_to_numpy(fig): |
|
"""Convert matplotlib figure to numpy array. |
|
|
|
Args: |
|
fig (matplotlib.figure.Figure): figure to convert. |
|
|
|
Returns: |
|
np.ndarray: numpy array of figure. |
|
|
|
""" |
|
buf = BytesIO() |
|
fig.savefig(buf, format="png", bbox_inches="tight") |
|
buf.seek(0) |
|
image = Image.open(buf).convert("RGB") |
|
image = np.array(image)[..., :3] |
|
buf.close() |
|
return image |
|
|
|
|
|
def plot_batch_with_named_masks( |
|
images, masks_dict, mask_colors=None, titles=None, **kwargs |
|
): |
|
""" |
|
Plot batch of images in rows, each column overlays a different mask from the dict. |
|
Mask labels are shown as column titles. If mask name is 'per_pixel_omega', show it |
|
directly with inferno colormap (no overlay). |
|
|
|
Args: |
|
images: np.ndarray, shape (batch, height, width, channels) |
|
masks_dict: dict of {name: mask}, each mask shape (batch, height, width, channels) |
|
mask_colors: dict of {name: color} or None (default colors used) |
|
""" |
|
mask_names = list(masks_dict.keys()) |
|
batch_size = images.shape[0] |
|
default_colors = ["red", "green", "#33aaff", "yellow", "magenta", "cyan"] |
|
mask_colors = mask_colors or { |
|
name: default_colors[i % len(default_colors)] |
|
for i, name in enumerate(mask_names) |
|
} |
|
|
|
|
|
columns = [] |
|
cmaps = [] |
|
for name in mask_names: |
|
if name == "per_pixel_omega": |
|
mask_np = np.array(masks_dict[name]) |
|
columns.append(np.squeeze(mask_np)) |
|
cmaps.append(["inferno"] * batch_size) |
|
else: |
|
columns.append(np.squeeze(images)) |
|
cmaps.append(["gray"] * batch_size) |
|
|
|
|
|
all_images = np.stack(columns, axis=0) |
|
|
|
all_images = ( |
|
np.transpose(all_images, (1, 0, 2, 3, 4)) |
|
if all_images.ndim == 5 |
|
else np.transpose(all_images, (1, 0, 2, 3)) |
|
) |
|
|
|
all_images = all_images.reshape(batch_size * len(mask_names), *images.shape[1:]) |
|
|
|
|
|
flat_cmaps = [] |
|
for row in range(batch_size): |
|
for col in range(len(mask_names)): |
|
flat_cmaps.append(cmaps[col][row]) |
|
|
|
fig, _ = plot_image_grid( |
|
all_images, |
|
ncols=len(mask_names), |
|
remove_axis=False, |
|
cmap=flat_cmaps, |
|
figsize=(8, 3.3), |
|
**kwargs, |
|
) |
|
|
|
|
|
for col_idx, name in enumerate(mask_names): |
|
if name == "per_pixel_omega": |
|
continue |
|
mask_np = np.array(masks_dict[name]) |
|
axes = fig.axes[col_idx : batch_size * len(mask_names) : len(mask_names)] |
|
for ax, mask_img in zip(axes, mask_np): |
|
add_shape_from_mask( |
|
ax, mask_img.squeeze(), color=mask_colors[name], alpha=0.3 |
|
) |
|
|
|
|
|
row_idx = 0 |
|
if titles is None: |
|
titles = mask_names |
|
for col_idx, name in enumerate(titles): |
|
ax_idx = row_idx * len(mask_names) + col_idx |
|
fig.axes[ax_idx].set_title(name, fontsize=9, color="white") |
|
fig.axes[ax_idx].set_facecolor("black") |
|
|
|
|
|
if "per_pixel_omega" in mask_names: |
|
col_idx = mask_names.index("per_pixel_omega") |
|
axes = fig.axes[col_idx : batch_size * len(mask_names) : len(mask_names)] |
|
|
|
|
|
top_ax = axes[0] |
|
bottom_ax = axes[-1] |
|
top_pos = top_ax.get_position() |
|
bottom_pos = bottom_ax.get_position() |
|
|
|
full_y0 = bottom_pos.y0 |
|
full_y1 = top_pos.y1 |
|
full_height = full_y1 - full_y0 |
|
|
|
|
|
scale = 0.8 |
|
height = full_height * scale |
|
y0 = full_y0 + (full_height - height) / 2 |
|
|
|
x0 = top_pos.x1 + 0.015 |
|
width = 0.015 |
|
|
|
|
|
cax = fig.add_axes([x0, y0, width, height]) |
|
|
|
im = axes[0].get_images()[0] if axes[0].get_images() else None |
|
cbar = fig.colorbar(im, cax=cax) |
|
cbar.set_label(r"Guidance weighting $\mathbf{p}$") |
|
cbar.ax.yaxis.set_major_locator(plt.MaxNLocator(nbins=6)) |
|
cbar.ax.yaxis.set_tick_params(labelsize=7) |
|
cbar.ax.yaxis.label.set_size(8) |
|
|
|
return fig |
|
|
|
|
|
def plot_dehazed_results( |
|
hazy_images, |
|
pred_tissue_images, |
|
pred_haze_images, |
|
diffusion_model, |
|
titles=("Hazy", "Dehazed", "Haze"), |
|
): |
|
"""Create and save visualization with optional mask overlays.""" |
|
|
|
|
|
input_shape = diffusion_model.input_shape |
|
stack_images = ops.stack( |
|
[ |
|
hazy_images, |
|
pred_tissue_images, |
|
pred_haze_images, |
|
] |
|
) |
|
stack_images = ops.reshape(stack_images, (-1, input_shape[0], input_shape[1])) |
|
|
|
|
|
fig, _ = plot_image_grid( |
|
stack_images, |
|
ncols=len(hazy_images), |
|
remove_axis=False, |
|
vmin=0, |
|
vmax=255, |
|
) |
|
|
|
for i, ax in enumerate(fig.axes): |
|
if i % len(hazy_images) == 0: |
|
label = titles[(i // len(hazy_images)) % len(titles)] |
|
ax.set_ylabel(label, fontsize=12) |
|
|
|
return fig |
|
|
|
|
|
def plot_metrics(metrics, limits, out_path): |
|
plt.style.use("seaborn-v0_8-darkgrid") |
|
fig, axes = plt.subplots(1, len(metrics), figsize=(7.2, 2.7), dpi=200) |
|
colors = ["#0057b7", "#ffb300", "#008744", "#d62d20"] |
|
metric_labels = { |
|
"CNR": r"CNR $\uparrow$", |
|
"gCNR": r"gCNR $\uparrow$", |
|
"KS_A": r"KS$_{septum}$ $\downarrow$", |
|
"KS_B": r"KS$_{ventricle}$ $\uparrow$", |
|
} |
|
|
|
legend_handles = [] |
|
import matplotlib.lines as mlines |
|
|
|
min_style = { |
|
"color": "crimson", |
|
"linestyle": "--", |
|
"lw": 1.2, |
|
"marker": "o", |
|
"markersize": 5, |
|
} |
|
max_style = { |
|
"color": "crimson", |
|
"linestyle": ":", |
|
"lw": 1.2, |
|
"marker": "s", |
|
"markersize": 5, |
|
} |
|
for idx, (ax, (name, values)) in enumerate(zip(axes, metrics.items())): |
|
ax.hist( |
|
values, |
|
bins=30, |
|
color=colors[idx % len(colors)], |
|
alpha=0.85, |
|
edgecolor="black", |
|
linewidth=0.7, |
|
) |
|
ax.set_xlabel(metric_labels.get(name, name), fontsize=11) |
|
if idx == 0: |
|
ax.set_ylabel("Count", fontsize=10) |
|
|
|
if name in limits: |
|
lims = limits[name] |
|
if len(legend_handles) == 0: |
|
|
|
min_handle = mlines.Line2D([], [], **min_style, label="min score") |
|
max_handle = mlines.Line2D([], [], **max_style, label="max score") |
|
legend_handles.extend([min_handle, max_handle]) |
|
if len(lims) > 0: |
|
ax.axvline(lims[0], **min_style) |
|
if len(lims) > 1: |
|
ax.axvline(lims[1], **max_style) |
|
ax.spines["top"].set_visible(False) |
|
ax.spines["right"].set_visible(False) |
|
ax.tick_params(axis="both", which="major", labelsize=9) |
|
|
|
fig.legend( |
|
handles=legend_handles, |
|
loc="upper center", |
|
ncol=2, |
|
fontsize=10, |
|
frameon=False, |
|
bbox_to_anchor=(0.5, 1.02), |
|
) |
|
return fig |
|
|
|
|
|
def plot_optimization_history_from_json( |
|
trials_data: List[Dict[str, Any]], output_path: Path, method: str |
|
): |
|
"""Plot optimization history directly from JSON data.""" |
|
|
|
|
|
completed_trials = [ |
|
t for t in trials_data if t["state"] == "COMPLETE" and t["value"] is not None |
|
] |
|
|
|
if not completed_trials: |
|
print("No completed trials found!") |
|
return |
|
|
|
|
|
completed_trials.sort(key=lambda x: x["number"]) |
|
|
|
trial_numbers = [t["number"] for t in completed_trials] |
|
values = [t["value"] for t in completed_trials] |
|
|
|
|
|
plt.style.use("seaborn-v0_8-darkgrid") |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(5, 3), dpi=600) |
|
|
|
|
|
ax.scatter( |
|
trial_numbers, |
|
values, |
|
c="#0057b7", |
|
alpha=0.6, |
|
s=30, |
|
edgecolor="black", |
|
linewidth=0.5, |
|
) |
|
|
|
|
|
best_values = [] |
|
current_best = values[0] |
|
for val in values: |
|
if val > current_best: |
|
current_best = val |
|
best_values.append(current_best) |
|
|
|
ax.plot( |
|
trial_numbers, |
|
best_values, |
|
color="#d62d20", |
|
linewidth=2.5, |
|
label="Best Value", |
|
marker="o", |
|
markersize=4, |
|
markevery=max(1, len(trial_numbers) // 20), |
|
) |
|
|
|
ax.set_xlabel("Trial", fontsize=11) |
|
ax.set_ylabel("Objective Value", fontsize=11) |
|
|
|
ax.legend(fontsize=10, frameon=False) |
|
|
|
|
|
ax.spines["top"].set_visible(False) |
|
ax.spines["right"].set_visible(False) |
|
ax.tick_params(axis="both", which="major", labelsize=9) |
|
|
|
|
|
fig.savefig( |
|
output_path / f"optimization_history_{method}.png", dpi=600, bbox_inches="tight" |
|
) |
|
fig.savefig( |
|
output_path / f"optimization_history_{method}.pdf", dpi=600, bbox_inches="tight" |
|
) |
|
plt.close(fig) |
|
|
|
|
|
def create_animation_frame(hazy_images, tissue_frame, haze_frame): |
|
"""Create a single animation frame from the tracked progress.""" |
|
batch, height, width = ops.shape(hazy_images) |
|
frame_stack = ops.stack( |
|
[ |
|
hazy_images, |
|
tissue_frame, |
|
haze_frame, |
|
] |
|
) |
|
frame_stack = ops.reshape(frame_stack, (-1, height, width)) |
|
fig_frame, _ = plot_image_grid( |
|
frame_stack, |
|
ncols=len(hazy_images), |
|
remove_axis=False, |
|
vmin=0, |
|
vmax=255, |
|
) |
|
labels = ["Hazy", "Haze", "Tissue"] |
|
for i, ax in enumerate(fig_frame.axes): |
|
label = labels[i % len(labels)] |
|
ax.set_ylabel(label, fontsize=12) |
|
frame_array = matplotlib_figure_to_numpy(fig_frame) |
|
plt.close(fig_frame) |
|
return frame_array |
|
|
|
|
|
def create_animation(hazy_images, diffusion_model, output_path, fps): |
|
"""Create animation from tracked progress frames.""" |
|
if not (len(diffusion_model.track_progress) > 1): |
|
log.warning( |
|
"Animation requested but no intermediate frames were tracked. " |
|
"Try reducing diffusion_steps or ensure progress tracking is enabled." |
|
) |
|
return |
|
|
|
log.info(f"Creating animation with {len(diffusion_model.track_progress)} frames...") |
|
|
|
animation_frames = [] |
|
progbar = keras.utils.Progbar( |
|
len(diffusion_model.track_progress), unit_name="frame" |
|
) |
|
for tissue_frame in diffusion_model.track_progress: |
|
haze_frame = hazy_images - tissue_frame - 1 |
|
tissue_frame = postprocess(tissue_frame, diffusion_model.input_range) |
|
haze_frame = postprocess(haze_frame, diffusion_model.input_range) |
|
_hazy_images = postprocess(hazy_images, diffusion_model.input_range) |
|
frame_array = create_animation_frame(_hazy_images, tissue_frame, haze_frame) |
|
animation_frames.append(frame_array) |
|
progbar.add(1) |
|
|
|
Path(output_path).parent.mkdir(parents=True, exist_ok=True) |
|
animation_path = Path(output_path).with_suffix(".gif") |
|
save_to_gif(animation_frames, animation_path, fps=fps) |
|
|
|
|
|
def main(json_file: str, output_dir: str = "plots", method: str = "semantic_dps"): |
|
json_path = Path(json_file) |
|
if not json_path.exists(): |
|
raise FileNotFoundError(f"JSON file not found: {json_file}") |
|
|
|
|
|
with open(json_path, "r") as f: |
|
trials_data = json.load(f) |
|
|
|
print(f"Loaded {len(trials_data)} trials from {json_file}") |
|
|
|
|
|
output_path = Path(output_dir) |
|
output_path.mkdir(parents=True, exist_ok=True) |
|
|
|
print("Generating optimization history plot...") |
|
plot_optimization_history_from_json(trials_data, output_path, method) |
|
|
|
|
|
if __name__ == "__main__": |
|
tyro.cli(main) |
|
|