|
import os |
|
import numpy |
|
import torch |
|
import random |
|
from PIL import Image |
|
import cv2 |
|
from diffusers import DDPMPipeline |
|
from diffusers.utils import make_image_grid |
|
import matplotlib.pyplot as plt |
|
import numpy |
|
from skimage.color import lab2rgb |
|
|
|
|
|
def evaluate(condition_images, config, epoch, pipeline):
    """Sample a batch from ``pipeline`` and save each image as a PNG file.

    The pipeline output is multiplied per channel by (100, 127, 127), clipped
    to [0, 100] for the first channel and [-128, 127] for the remaining ones,
    converted with ``skimage.color.lab2rgb`` and written to
    ``<config.output_dir>/samples/{epoch:04d}_{i:02d}_rgb.png``.

    Args:
        condition_images: Forwarded to the pipeline as ``condition_images``.
        config: Object providing ``output_dir``, ``eval_batch_size`` and
            ``seed``.
        epoch: Integer epoch number, used only in the output file names.
        pipeline: Callable whose result exposes ``.images``; the array is
            indexed as 4-D with channels on axis 3.
            NOTE(review): presumably the pipeline emits Lab values normalised
            to roughly [0, 1] / [-1, 1] -- confirm against the pipeline.

    Returns:
        None. The only effect is the files written to disk.
    """
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)

    # Use a dedicated generator instead of torch.manual_seed(): sampling stays
    # reproducible for a given seed, but the global RNG (and therefore the
    # caller's training randomness) is no longer rewound on every evaluation.
    generator = torch.Generator(device="cpu").manual_seed(config.seed)

    lab_images = pipeline(
        condition_images=condition_images,
        batch_size=config.eval_batch_size,
        output_type="numpy",
        generator=generator,
    ).images

    # Per-channel rescale; the (1, 1, 1, 3) factor broadcasts over the batch.
    lab_images = lab_images * numpy.array([[[[100, 127, 127]]]])

    # Clamp each channel group to the range lab2rgb is given here.
    lab_images = numpy.concatenate(
        (
            numpy.clip(lab_images[..., :1], a_min=0, a_max=100),
            numpy.clip(lab_images[..., 1:], a_min=-128, a_max=127),
        ),
        axis=3,
    )

    rgb_images = lab2rgb(lab_images, channel_axis=3)
    rgb_images = (rgb_images * 255).astype(numpy.uint8)

    for i, image in enumerate(rgb_images):
        # cv2.imwrite interprets 3-channel arrays as BGR, so the RGB output of
        # lab2rgb must be reordered or red and blue end up swapped on disk.
        cv2.imwrite(
            f"{test_dir}/{epoch:04d}_{i:02d}_rgb.png",
            cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
        )
|
|
|
|
|
def less_noisy_image(sample, timestep, betas, prev_timestep, model_output):
    """Move ``sample`` from ``timestep`` to ``prev_timestep`` in one step.

    The cumulative product of ``1 - betas`` is looked up at both timesteps and
    combined with ``sample`` and ``model_output`` in closed form; no noise is
    drawn here.

    Args:
        sample: Current sample tensor.
        timestep: Index tensor into the schedule for the current step.
        betas: 1-D tensor of beta values; ``cumprod`` is taken along dim 0.
        prev_timestep: Index tensor for the target step. Entries that are not
            strictly positive use a cumulative alpha of 1.0.
        model_output: Network output for ``sample`` at ``timestep``.
            NOTE(review): the sqrt(alpha) * output + sqrt(1 - alpha) * sample
            mix below looks like a v-prediction -> noise conversion, and the
            update resembles the PNDM scheduler's previous-sample formula --
            confirm against the training objective.

    Returns:
        The sample tensor at ``prev_timestep``.
    """

    def _append_three_dims(tensor):
        # Equivalent to three successive unsqueeze(-1) calls, so per-item
        # scalars broadcast against the trailing dimensions of the sample.
        return tensor[..., None, None, None]

    cumulative_alphas = torch.cumprod(1.0 - betas, dim=0)
    alpha_t_raw = cumulative_alphas[timestep]

    # NOTE(review): prev_timestep == 0 maps to 1.0 here (the test is "> 0",
    # not ">= 0") -- confirm this is intended.
    alpha_prev = _append_three_dims(
        torch.where(prev_timestep > 0, cumulative_alphas[prev_timestep], 1.0)
    )
    one_minus_alpha_prev = 1 - alpha_prev

    alpha_t = _append_three_dims(alpha_t_raw)
    one_minus_alpha_t = _append_three_dims(1 - alpha_t_raw)

    mixed_output = (alpha_t**0.5) * model_output + (one_minus_alpha_t**0.5) * sample
    sample_scale = (alpha_prev / alpha_t) ** 0.5

    # The denominator is evaluated on flattened per-item values and expanded
    # again afterwards, exactly as the broadcast below expects.
    alpha_t_flat = alpha_t.view(-1)
    denominator_flat = alpha_t_flat * one_minus_alpha_prev.view(-1) ** 0.5 + (
        alpha_t_flat * one_minus_alpha_t.view(-1) * alpha_prev.view(-1)
    ) ** 0.5
    denominator = _append_three_dims(denominator_flat)

    return sample_scale * sample - (alpha_prev - alpha_t) * mixed_output / denominator
|
|
|
|
|
def make_grid(images, rows, cols):
    """Tile ``images`` row by row onto a new RGB canvas.

    Args:
        images: Sequence of PIL images; the cell size is taken from
            ``images[0].size``, so the sequence must not be empty.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new ``"RGB"`` image of size ``(cols * w, rows * h)`` with image
        ``i`` pasted at column ``i % cols`` and row ``i // cols``.
    """
    cell_w, cell_h = images[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
|
|