Spaces: Running on Zero
import copy | |
import os | |
from pathlib import Path | |
os.environ["KERAS_BACKEND"] = "jax" | |
import jax | |
import keras | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import scipy | |
import tyro | |
import zea | |
from keras import ops | |
from PIL import Image | |
from skimage import filters | |
from zea import Config, init_device, log | |
from zea.internal.operators import Operator | |
from zea.models.diffusion import ( | |
DPS, | |
DiffusionModel, | |
diffusion_guidance_registry, | |
) | |
from zea.tensor_ops import L2 | |
from zea.utils import translate | |
from plots import create_animation, plot_batch_with_named_masks, plot_dehazed_results | |
from utils import ( | |
apply_bottom_preservation, | |
extract_skeleton, | |
load_image, | |
postprocess, | |
preprocess, | |
smooth_L1, | |
) | |
class IdentityOperator(Operator):
    """Trivial measurement operator with forward model y = x.

    Used so the diffusion guidance compares predictions directly against
    the (hazy) input images, with no degradation model in between.
    """

    def forward(self, data):
        """Return the input unchanged (identity forward model)."""
        return data

    def __str__(self):
        return "y = x"
class SemanticDPS(DPS):
    """Diffusion posterior sampling (DPS) with semantic, per-pixel guidance.

    Builds a per-pixel guidance-weight map ("omega map") from several masks
    (ventricle/septum segmentation, fixed top/bottom rows, skeleton, dark
    pixels) and adds a sparsity prior on predicted haze pixels to the usual
    DPS measurement error.
    """

    def __init__(
        self,
        diffusion_model,
        segmentation_model,
        operator,
        disable_jit=False,
        **kwargs,
    ):
        """Initialize the diffusion guidance.

        Args:
            diffusion_model: The diffusion model to use for guidance.
            segmentation_model: Model producing per-class masks; channel 0 is
                the ventricle, channel 1 the septum (see
                ``_get_segmentation_mask``).
            operator: The forward (measurement) operator to use for guidance.
            disable_jit: Whether to disable JIT compilation.
            **kwargs: Forwarded to ``self.setup`` (defined on the DPS base).
        """
        self.diffusion_model = diffusion_model
        self.segmentation_model = segmentation_model
        self.operator = operator
        self.disable_jit = disable_jit
        self.setup(**kwargs)

    def _get_fixed_mask(
        self,
        images,
        bottom_px=40,
        top_px=20,
    ):
        """Return a binary mask (1.0) covering the top ``top_px`` and bottom
        ``bottom_px`` rows of each image; 0.0 elsewhere.

        Output shape matches ``images`` (batch, height, width, channels).
        """
        batch_size, height, width, channels = ops.shape(images)
        # Create row indices for each pixel
        row_indices = ops.arange(height)
        row_indices = ops.reshape(row_indices, (height, 1))
        row_indices = ops.tile(row_indices, (1, width))
        # 1.0 on the first `top_px` and last `bottom_px` rows, else 0.0
        fixed_mask = ops.where(
            ops.logical_or(row_indices < top_px, row_indices >= height - bottom_px),
            1.0,
            0.0,
        )
        # Broadcast the (height, width) mask to (batch, height, width, channels)
        fixed_mask = ops.expand_dims(fixed_mask, axis=0)
        fixed_mask = ops.expand_dims(fixed_mask, axis=-1)
        fixed_mask = ops.tile(fixed_mask, (batch_size, 1, 1, channels))
        return fixed_mask

    def _get_segmentation_mask(self, images, threshold, sigma):
        """Segment ventricle and septum, returning two soft masks in [0, 1].

        Images are clipped to the diffusion model's input range and rescaled
        to (-1, 1) before segmentation. Each raw mask is thresholded to
        binary, Gaussian-smoothed with ``sigma``, and min-max renormalized.

        Returns:
            Tuple of (mask_vent, mask_sept) numpy arrays with a trailing
            channel axis.
        """
        input_range = self.diffusion_model.input_range
        images = ops.clip(images, input_range[0], input_range[1])
        images = translate(images, input_range, (-1, 1))
        masks = self.segmentation_model(images)
        mask_vent = masks[..., 0]  # ROI 1 ventricle
        mask_sept = masks[..., 1]  # ROI 2 septum

        def _preprocess_mask(mask):
            # Host-side (numpy) post-processing: binarize, smooth, renormalize.
            mask = ops.convert_to_numpy(mask)
            mask = np.expand_dims(mask, axis=-1)
            mask = np.where(mask > threshold, 1.0, 0.0)
            mask = filters.gaussian(mask, sigma=sigma)
            # Min-max normalize back into [0, 1]; +1e-8 guards an all-zero mask
            mask = (mask - ops.min(mask)) / (ops.max(mask) - ops.min(mask) + 1e-8)
            return mask

        mask_vent = _preprocess_mask(mask_vent)
        mask_sept = _preprocess_mask(mask_sept)
        return mask_vent, mask_sept

    def _get_dark_mask(self, images):
        """Return a binary mask of pixels at (within 1e-6 of) the minimum of
        the diffusion model's input range, i.e. fully black pixels."""
        min_val = self.diffusion_model.input_range[0]
        dark_mask = ops.where(ops.abs(images - min_val) < 1e-6, 1.0, 0.0)
        return dark_mask

    def make_omega_map(
        self, images, mask_params, fixed_mask_params, skeleton_params, guidance_kwargs
    ):
        """Compose all masks into a per-pixel guidance-weight map and a haze mask.

        Args:
            images: Batch of (hazy) input images.
            mask_params: Kwargs for ``_get_segmentation_mask``.
            fixed_mask_params: Kwargs for ``_get_fixed_mask``.
            skeleton_params: Kwargs for ``extract_skeleton``.
            guidance_kwargs: Must contain the scalar weights ``omega``
                (background), ``omega_vent`` (ventricle) and ``omega_sept``
                (strong / tissue regions).

        Returns:
            The dict from ``get_masks`` with two extra entries:
            ``per_pixel_omega`` (guidance weight per pixel) and ``haze``
            (pixels treated as haze inside the ventricle).
        """
        masks = self.get_masks(images, mask_params, fixed_mask_params, skeleton_params)
        masks_vent = masks["vent"]
        masks_sept = masks["sept"]
        masks_fixed = masks["fixed"]
        masks_skeleton = masks["skeleton"]
        masks_dark = masks["dark"]
        # "Strong" regions get full guidance: septum, fixed rows, skeleton, black pixels
        masks_strong = ops.clip(
            masks_sept + masks_fixed + masks_skeleton + masks_dark, 0, 1
        )
        # Background = neither strong nor ventricle
        background = ops.where(masks_strong < 0.1, 1.0, 0.0) * ops.where(
            masks_vent == 0, 1.0, 0.0
        )
        # Ventricle minus any overlap with strong regions
        masks_vent_filtered = masks_vent * (1.0 - masks_strong)
        per_pixel_omega = (
            guidance_kwargs["omega"] * background
            + guidance_kwargs["omega_vent"] * masks_vent_filtered
            + guidance_kwargs["omega_sept"] * masks_strong
        )
        # Candidate haze region: inside ventricle, outside strong regions
        haze_mask_components = (masks_vent > 0.5) * (1 - masks_strong > 0.5)
        haze_mask = []
        # Per-image host-side connected-component check (scipy runs on numpy,
        # so this loop is not jittable)
        for i, m in enumerate(haze_mask_components):
            if scipy.ndimage.label(m)[1] > 1:
                # masks_strong _splits_ masks_vent in 2 or more components
                # so we fall back to masks_vent
                haze_mask.append(masks_vent[i])
                # also remove guidance from this region to avoid bringing haze in
                # (.at[].set is JAX's functional update; backend is JAX, set at import)
                per_pixel_omega = per_pixel_omega.at[i].set(
                    per_pixel_omega[i] * (1 - masks_vent[i])
                )
            else:
                # masks_strong 'shaves off' some of masks_vent,
                # where there is tissue
                haze_mask.append((masks_vent * (1 - masks_strong))[i])
        haze_mask = ops.stack(haze_mask, axis=0)
        masks["per_pixel_omega"] = per_pixel_omega
        masks["haze"] = haze_mask
        return masks

    def get_masks(self, images, mask_params, fixed_mask_params, skeleton_params):
        """Generate all component masks from the input images.

        Returns:
            Dict with keys ``vent``, ``sept``, ``fixed``, ``skeleton``,
            ``dark``, each a mask batch aligned with ``images``.
        """
        masks_vent, masks_sept = self._get_segmentation_mask(images, **mask_params)
        masks_fixed = self._get_fixed_mask(images, **fixed_mask_params)
        masks_skeleton = extract_skeleton(
            images, self.diffusion_model.input_range, **skeleton_params
        )
        masks_dark = self._get_dark_mask(images)
        return {
            "vent": masks_vent,
            "sept": masks_sept,
            "fixed": masks_fixed,
            "skeleton": masks_skeleton,
            "dark": masks_dark,
        }

    def compute_error(
        self,
        noisy_images,
        measurements,
        noise_rates,
        signal_rates,
        per_pixel_omega,
        haze_mask,
        eta=0.01,
        smooth_l1_beta=0.5,
        **kwargs,
    ):
        """Compute measurement error for diffusion posterior sampling.

        Args:
            noisy_images: Noisy images at the current diffusion step.
            measurements: Target measurements (here: the hazy images).
            noise_rates: Current noise rates.
            signal_rates: Current signal rates.
            per_pixel_omega: Per-pixel guidance weight map
                (from ``make_omega_map``).
            haze_mask: Mask of pixels treated as haze inside the ventricle.
            eta: Weight of the haze sparsity prior relative to the
                measurement error.
            smooth_l1_beta: ``beta`` of the smooth-L1 haze penalty.
            **kwargs: Additional arguments for the operator.

        Returns:
            Tuple of (total_error, (pred_noises, pred_images))
        """
        pred_noises, pred_images = self.diffusion_model.denoise(
            noisy_images,
            noise_rates,
            signal_rates,
            training=False,
        )
        # Weighted DPS data-fidelity term
        measurement_error = L2(
            per_pixel_omega
            * (measurements - self.operator.forward(pred_images, **kwargs))
        )
        hazy_pixels = pred_images * haze_mask
        # L1 penalty on haze pixels
        # add +1 to make -1 (=black) the 'sparse' value
        haze_prior_error = smooth_L1(hazy_pixels + 1, beta=smooth_l1_beta)
        total_error = measurement_error + eta * haze_prior_error
        return total_error, (pred_noises, pred_images)
def init(config):
    """Initialize models, operator, and guidance objects for semantic-dps dehazing.

    Args:
        config: Config with ``diffusion_model_path`` and
            ``segmentation_model_path``.

    Returns:
        The diffusion model, wired up with the identity operator and the
        semantic DPS guidance function.
    """
    # Load the two models first, then wire them together through the guidance.
    diffusion_model = DiffusionModel.from_preset(
        config.diffusion_model_path,
    )
    log.success(
        f"Diffusion model loaded from {log.yellow(config.diffusion_model_path)}"
    )

    segmentation_model = load_segmentation_model(config.segmentation_model_path)
    log.success(
        f"Segmentation model loaded from {log.yellow(config.segmentation_model_path)}"
    )

    operator = IdentityOperator()
    guidance_fn = SemanticDPS(
        diffusion_model=diffusion_model,
        segmentation_model=segmentation_model,
        operator=operator,
    )
    diffusion_model._init_operator_and_guidance(operator, guidance_fn)
    return diffusion_model
def load_segmentation_model(path):
    """Load and return a Keras segmentation model saved at ``path``."""
    return keras.saving.load_model(path)
def run(
    hazy_images,
    diffusion_model: DiffusionModel,
    seed,
    guidance_kwargs: dict,
    mask_params: dict,
    fixed_mask_params: dict,
    skeleton_params: dict,
    batch_size: int = 4,
    diffusion_steps: int = 100,
    threshold_output_quantile: float = None,
    preserve_bottom_percent: float = 30.0,
    bottom_transition_width: float = 10.0,
    verbose: bool = True,
):
    """Dehaze a batch of images with semantic diffusion posterior sampling.

    Args:
        hazy_images: Batch of input images (any array-like accepted by
            ``preprocess``).
        diffusion_model: Diffusion model initialized by ``init`` (its
            ``guidance_fn`` must be a ``SemanticDPS``).
        seed: PRNG key/seed forwarded to ``posterior_sample``.
        guidance_kwargs: Scalar guidance weights; must contain ``omega``,
            ``omega_vent``, ``omega_sept``, ``eta`` and ``smooth_l1_beta``.
        mask_params: Kwargs for the segmentation mask.
        fixed_mask_params: Kwargs for the fixed top/bottom-row mask.
        skeleton_params: Kwargs for skeleton extraction.
        batch_size: Number of images per posterior-sampling batch.
        diffusion_steps: Number of reverse-diffusion steps.
        threshold_output_quantile: If set, pixels of the dehazed output below
            this per-image quantile are clamped to the input-range minimum.
        preserve_bottom_percent: Percentage of image height at the bottom
            where the original input is preserved (0 disables).
        bottom_transition_width: Width of the smooth blend into the
            preserved bottom region.
        verbose: Whether to show the batch progress bar.

    Returns:
        Tuple of (hazy_images, pred_tissue_images, pred_haze_images,
        masks_out), all postprocessed to the output range; ``masks_out``
        maps mask names to batched arrays.
    """
    input_range = diffusion_model.input_range
    hazy_images = preprocess(hazy_images, normalization_range=input_range)

    pred_tissue_images = []
    masks_out = []
    num_images = hazy_images.shape[0]
    # Ceil division so a final partial batch is still processed
    num_batches = (num_images + batch_size - 1) // batch_size
    progbar = keras.utils.Progbar(num_batches, verbose=verbose, unit_name="batch")

    for batch_idx in range(num_batches):
        batch = hazy_images[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        # Per-batch semantic masks and per-pixel guidance weights
        masks = diffusion_model.guidance_fn.make_omega_map(
            batch, mask_params, fixed_mask_params, skeleton_params, guidance_kwargs
        )
        batch_images = diffusion_model.posterior_sample(
            batch,
            n_samples=1,
            n_steps=diffusion_steps,
            seed=seed,
            verbose=True,
            per_pixel_omega=masks["per_pixel_omega"],
            haze_mask=masks["haze"],
            eta=guidance_kwargs["eta"],
            smooth_l1_beta=guidance_kwargs["smooth_l1_beta"],
        )
        # Drop the n_samples axis (n_samples=1)
        batch_images = ops.take(batch_images, 0, axis=1)
        pred_tissue_images.append(batch_images)
        masks_out.append(masks)
        progbar.update(batch_idx + 1)

    pred_tissue_images = ops.concatenate(pred_tissue_images, axis=0)
    # Merge the per-batch mask dicts into one dict of batched arrays
    masks_out = {
        key: ops.concatenate([m[key] for m in masks_out], axis=0)
        for key in masks_out[0].keys()
    }
    # Haze estimate = input - tissue; -1 re-centers into the (-1, 1) range
    pred_haze_images = hazy_images - pred_tissue_images - 1

    if threshold_output_quantile is not None:
        # Clamp faint pixels (below the per-image quantile) to black
        threshold_value = ops.quantile(
            pred_tissue_images, threshold_output_quantile, axis=(1, 2), keepdims=True
        )
        pred_tissue_images = ops.where(
            pred_tissue_images < threshold_value, input_range[0], pred_tissue_images
        )

    # Apply bottom preservation with smooth transition
    if preserve_bottom_percent > 0:
        pred_tissue_images = apply_bottom_preservation(
            pred_tissue_images,
            hazy_images,
            preserve_bottom_percent=preserve_bottom_percent,
            transition_width=bottom_transition_width,
        )

    pred_tissue_images = postprocess(pred_tissue_images, input_range)
    hazy_images = postprocess(hazy_images, input_range)
    pred_haze_images = postprocess(pred_haze_images, input_range)
    return hazy_images, pred_tissue_images, pred_haze_images, masks_out
def main(
    input_folder: str = "./assets",
    output_folder: str = "./temp",
    num_imgs_plot: int = 5,
    device: str = "auto:1",
    config: str = "configs/semantic_dps.yaml",
):
    """Dehaze all PNGs in ``input_folder`` and save results and figures.

    Args:
        input_folder: Folder with input ``*.png`` images.
        output_folder: Folder where dehazed images are written.
        num_imgs_plot: Number of images shown in the results figure.
        device: Device spec passed to ``zea.init_device``.
        config: Path to the YAML run configuration.
    """
    num_img = num_imgs_plot
    zea.visualize.set_mpl_style()
    init_device(device)

    config = Config.from_yaml(config)
    seed = jax.random.PRNGKey(config.seed)

    # Sort for a deterministic image order across runs
    paths = sorted(Path(input_folder).glob("*.png"))
    output_folder = Path(output_folder)

    images = []
    for path in paths:
        image = load_image(path)
        images.append(image)
    images = ops.stack(images, axis=0)

    diffusion_model = init(config)
    hazy_images, pred_tissue_images, pred_haze_images, masks = run(
        images,
        diffusion_model=diffusion_model,
        seed=seed,
        **config.params,
    )

    # Save each dehazed image under its original file name
    output_folder.mkdir(parents=True, exist_ok=True)
    for image, path in zip(pred_tissue_images, paths):
        image = ops.convert_to_numpy(image)
        file_name = path.name
        Image.fromarray(image).save(output_folder / file_name)

    fig = plot_dehazed_results(
        hazy_images[:num_img],
        pred_tissue_images[:num_img],
        pred_haze_images[:num_img],
        diffusion_model,
        titles=[
            r"Hazy $\mathbf{y}$",
            r"Dehazed $\mathbf{\hat{x}}$",
            r"Haze $\mathbf{\hat{h}}$",
        ],
    )
    path = Path("dehazed_results.png")
    save_kwargs = {"bbox_inches": "tight", "dpi": 300}
    fig.savefig(path, **save_kwargs)
    fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
    # Fixed: previously logged "Segmentation steps saved" for this figure
    log.success(f"Dehazed results saved to {log.yellow(path)}")

    # The haze mask is not part of the segmentation-steps figure
    masks_viz = copy.deepcopy(masks)
    masks_viz.pop("haze")
    num_img = 2  # hardcoded as the plotting figure only neatly supports 2 rows
    masks_viz = {k: v[:num_img] for k, v in masks_viz.items()}
    fig = plot_batch_with_named_masks(
        images[:num_img],
        masks_viz,
        titles=[
            r"Ventricle $v(\mathbf{y})$",
            r"Septum $s(\mathbf{y})$",
            r"Fixed",
            r"Skeleton $t(\mathbf{y})$",
            r"Dark $b(\mathbf{y})$",
            r"Guidance $d(\mathbf{y})$",
        ],
    )
    path = Path("segmentation_steps.png")
    fig.savefig(path, **save_kwargs)
    fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
    log.success(f"Segmentation steps saved to {log.yellow(path)}")

    # Animate the denoising trajectory of the last processed batch
    last_batch_size = len(diffusion_model.track_progress[0])
    create_animation(
        preprocess(hazy_images[-last_batch_size:], diffusion_model.input_range),
        diffusion_model,
        output_path="animation.gif",
        fps=10,
    )
    plt.close("all")
if __name__ == "__main__":
    # CLI entry point: tyro builds the argument parser from main()'s signature.
    tyro.cli(main)