"""Dataset helpers and seeded diffusion sampling loops.

Contains a folder-based image dataset/loader plus four sampling entry points
(`compress`, `generate_ours`, `decompress`, `inf_generate`) that drive a
`PipelineWrapper` model one timestep at a time.
"""
import os
import random
import time
from glob import glob
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
from tqdm import tqdm

from latent_models import PipelineWrapper
from util.img_utils import clear_color


def set_seed(seed: int) -> None:
    """Seed torch (CPU and every CUDA device), NumPy and `random` with `seed`."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False


class MinusOneToOne(torch.nn.Module):
    """Transform applying `x * 2 - 1` (maps the range [0, 1] onto [-1, 1])."""

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor * 2 - 1


class ResizePIL(torch.nn.Module):
    """Resize a PIL image to a fixed size; a `None` size leaves it untouched."""

    def __init__(self, image_size: Optional[Union[int, Tuple[int, int]]] = None):
        super().__init__()
        # A single int is shorthand for a square target size.
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size

    def forward(self, pil_image: Image.Image) -> Image.Image:
        if self.image_size is not None:
            pil_image = pil_image.resize(self.image_size)
        return pil_image


def get_loader(datadir: str, batch_size: int = 1,
               crop_to: Optional[Union[int, Tuple[int, int]]] = None,
               include_path: bool = False) -> DataLoader:
    """Build a shuffled, single-process `DataLoader` over `FoldersDataset`.

    Args:
        datadir: Directory (searched recursively) or single image file.
        batch_size: Number of samples per batch.
        crop_to: Target size handed to `ResizePIL`. Despite the name, images
            are resized, not cropped; `None` keeps the original size.
        include_path: Forwarded to `FoldersDataset`.

    Returns:
        A `DataLoader` whose images are tensors rescaled by `MinusOneToOne`.
    """
    transform = transforms.Compose([
        ResizePIL(crop_to),
        transforms.ToTensor(),
        MinusOneToOne(),
    ])
    loader = DataLoader(FoldersDataset(datadir, transform, include_path=include_path),
                        batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False)
    return loader


class FoldersDataset(VisionDataset):
    """Image dataset over every `.png`, `.JPEG` and `.jpg` file under a root.

    `root` may also be a single existing file, in which case the dataset
    contains exactly that file.
    """

    def __init__(self, root: str, transforms: Optional[Callable] = None,
                 include_path: bool = False) -> None:
        """Collect the file list.

        Raises:
            AssertionError: `root` is a directory containing no matching file.
            FileNotFoundError: `root` does not exist.
        """
        super().__init__(root, transforms)
        self.include_path = include_path
        self.root = root

        if os.path.isdir(root):
            self.fpaths = glob(os.path.join(root, '**', '*.png'), recursive=True)
            self.fpaths += glob(os.path.join(root, '**', '*.JPEG'), recursive=True)
            self.fpaths += glob(os.path.join(root, '**', '*.jpg'), recursive=True)
            # Sorted so that indices are stable across runs.
            self.fpaths = sorted(self.fpaths)
            assert len(self.fpaths) > 0, "File list is empty. Check the root."
        elif os.path.exists(root):
            self.fpaths = [root]
        else:
            raise FileNotFoundError(f"File not found: {root}")

    def __len__(self) -> int:
        return len(self.fpaths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, str, str]:
        """Return `(image, name, relative_dir)` for the file at `index`.

        `name` is the base file name up to its first `os.extsep`.
        `relative_dir` is the file's directory relative to `root` when
        `include_path` is set, and `""` otherwise (also `""` for files directly
        under `root` or when `root` is a single file).
        """
        fpath = self.fpaths[index]
        img = Image.open(fpath).convert('RGB')

        if self.transforms is not None:
            img = self.transforms(img)

        path = ""
        if self.include_path and os.path.isdir(self.root):
            # relpath instead of slicing by len(root): slicing silently drops
            # a character when root carries a trailing separator.
            rel_dir = os.path.relpath(os.path.dirname(fpath), self.root)
            if rel_dir != os.curdir:
                path = rel_dir
        return img, os.path.basename(fpath).split(os.extsep)[0], path


@spaces.GPU
def compress(model: PipelineWrapper,
             img_to_compress: torch.Tensor,
             num_noises: int,
             loaded_indices,
             device,
             ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the sampling loop, picking one codebook noise per timestep.

    At every timestep a codebook of `num_noises` noises is drawn after
    seeding with the step index. For steps with `t >= 1` the chosen noise is
    either the one with the largest inner product with `enc_im - x_0_hat`
    (`loaded_indices is None`) or the one named by `loaded_indices[idx]`.
    For the remaining steps the first codebook entry is used.

    Args:
        model: Pipeline wrapper providing the encode/denoise/decode steps.
        img_to_compress: Image batch to encode; `None` substitutes an
            all-zero image of the model's image size.
        num_noises: Codebook size per timestep.
        loaded_indices: Optional per-step indices to replay instead of
            searching.
        device: Device on which the noises are created.

    Returns:
        `(decoded image, selected indices)`; the indices are squeezed and
        moved to the CPU.
    """
    # model.set_timesteps(model.num_timesteps, device=device)
    dtype = model.dtype
    prompt_embeds = model.encode_prompt("", None)
    set_seed(88888888)
    if img_to_compress is None:
        img_to_compress = torch.zeros(1, 3, model.get_image_size(), model.get_image_size(),
                                      device=device)
    enc_im = model.encode_image(img_to_compress.to(dtype))
    kwargs = model.get_pre_kwargs(height=img_to_compress.shape[-2],
                                  width=img_to_compress.shape[-1],
                                  prompt_embeds=prompt_embeds)
    set_seed(100000)
    xt = torch.randn(1, *enc_im.shape[1:], device=device, dtype=dtype)

    result_noise_indices = []
    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        # Re-seeding with the step index makes the codebook reproducible.
        set_seed(idx)
        noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)

        _, epst, _ = model.get_epst(xt, t, prompt_embeds, 0.0, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)

        if t >= 1:
            if loaded_indices is None:
                dot_prod = torch.matmul(noise.view(noise.shape[0], -1),
                                        (enc_im - x_0_hat).view(enc_im.shape[0], -1).transpose(0, 1))
                best_idx = torch.argmax(dot_prod)
            else:
                best_idx = loaded_indices[idx]
            best_noise = noise[best_idx]
            result_noise_indices.append(best_idx)
        else:
            best_noise = noise[0]

        xt = model.finish_step(xt, x_0_hat, epst, t, best_noise.unsqueeze(0), eta=None)

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        # Retry the decode with the latent moved to the CPU.
        img = model.decode_image(xt.to('cpu'))

    return img, torch.tensor(result_noise_indices).squeeze().cpu()


@spaces.GPU
def generate_ours(model: PipelineWrapper,
                  num_noises: int,
                  num_noises_to_optimize: int,
                  prompt: str = "",
                  negative_prompt: Optional[str] = None,
                  indices=None,
                  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate an image by choosing one codebook noise per timestep.

    For steps with `t >= 1` and `indices is None`, a time-seeded random
    subset of `num_noises_to_optimize` codebook entries is scored by its
    inner product with `epst_uncond - epst_cond` and the best one is kept.
    When `indices` is given, `indices[idx]` is replayed instead. Other steps
    use zero noise.

    Args:
        model: Pipeline wrapper providing the denoise/decode steps.
        num_noises: Codebook size per timestep.
        num_noises_to_optimize: Number of codebook entries scored per step.
        prompt: Text prompt; `None` is treated as `""`.
        negative_prompt: Optional negative prompt.
        indices: Optional per-step indices (tensor or sequence of ints).

    Returns:
        `(decoded image, selected indices)`; the indices are squeezed and
        on the CPU, and empty when no step satisfied `t >= 1`.
    """
    device = model.device
    dtype = model.dtype
    # print(num_noises, num_noises_to_optimize, flush=True)
    # model.set_timesteps(model.num_timesteps, device=device)

    set_seed(88888888)
    if prompt is None:
        prompt = ""
    prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs = model.get_pre_kwargs(height=model.get_image_size(),
                                  width=model.get_image_size(),
                                  prompt_embeds=prompt_embeds)
    set_seed(100000)
    xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()),
                     device=device, dtype=dtype)

    result_noise_indices = []
    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        set_seed(idx)
        noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)  # Codebook

        _, epst_uncond, epst_cond = model.get_epst(xt, t, prompt_embeds, 1.0,
                                                   return_everything=True, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst_uncond, t)

        if t >= 1:
            if indices is None:
                prev_classif_score = epst_uncond - epst_cond
                # The codebook above is already drawn, so re-seeding from the
                # clock only randomizes which entries get scored.
                set_seed(int(time.time_ns() & 0xFFFFFFFF))
                noise_indices = torch.randint(0, num_noises, size=(num_noises_to_optimize,),
                                              device=device)
                loss = torch.matmul(noise[noise_indices].view(num_noises_to_optimize, -1),
                                    prev_classif_score.view(prev_classif_score.shape[0], -1).transpose(0, 1))
                best_idx = noise_indices[torch.argmax(loss)]
            else:
                best_idx = indices[idx]
            best_noise = noise[best_idx]
            # as_tensor keeps tensors as they are and wraps plain ints, so
            # the final stack works for both kinds of `indices`.
            result_noise_indices.append(torch.as_tensor(best_idx))
        else:
            best_noise = torch.zeros_like(noise[0])

        xt = model.finish_step(xt, x_0_hat, epst_uncond, t, best_noise)

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        # Retry the decode with the latent moved to the CPU.
        img = model.decode_image(xt.to('cpu'))

    if result_noise_indices:
        selected = torch.stack(result_noise_indices).squeeze()
    else:
        # torch.stack rejects an empty list.
        selected = torch.empty(0, dtype=torch.long)
    return img, selected.cpu()


def decompress(model: PipelineWrapper,
               image_size: Tuple[int, int],
               indices: Dict[str, torch.Tensor],
               num_noises: int,
               prompt: str = "",
               negative_prompt: Optional[str] = None,
               tedit: int = 0,
               new_prompt: str = "",
               new_negative_prompt: Optional[str] = None,
               guidance_scale: float = 3.0,
               num_pursuit_noises: Optional[int] = 1,
               num_pursuit_coef_bits: Optional[int] = 3,
               t_range: Tuple[int, int] = (999, 0),
               robust_randn: bool = False
               ) -> torch.Tensor:
    """Rebuild an image from stored per-timestep noise/coefficient indices.

    Args:
        model: Pipeline wrapper providing the denoise/decode steps.
        image_size: Image size; its last two entries are used as
            height and width.
        indices: Mapping with keys `'noise_indices'` and `'coeff_indices'`,
            each indexed by step.
        num_noises: Codebook size for steps inside `t_range`.
        prompt: Prompt used for steps with `idx < tedit`.
        negative_prompt: Negative prompt paired with `prompt`.
        tedit: Step index from which the new prompt is used.
        new_prompt: Prompt used for steps with `idx >= tedit`.
        new_negative_prompt: Negative prompt paired with `new_prompt`.
        guidance_scale: Guidance scale passed to `model.get_epst`.
        num_pursuit_noises: Noises combined per step; `None` means 1.
        num_pursuit_coef_bits: Bits per mixing coefficient; `None` means 1.
        t_range: `(high, low)` timestep bounds; outside them a single
            noise is drawn instead of a full codebook.
        robust_randn: Draw noise through `get_robust_randn` instead of
            `torch.randn`.

    Returns:
        The decoded image.
    """
    noise_indices = indices['noise_indices']
    coeffs_indices = indices['coeff_indices']
    num_pursuit_noises = num_pursuit_noises if num_pursuit_noises is not None else 1
    num_pursuit_coef_bits = num_pursuit_coef_bits if num_pursuit_coef_bits is not None else 1

    device = model.device
    dtype = model.dtype
    # model.set_timesteps(model.num_timesteps, device=device)

    set_seed(88888888)
    orig_prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs_orig = model.get_pre_kwargs(height=image_size[-2],
                                       width=image_size[-1],
                                       prompt_embeds=orig_prompt_embeds)
    if new_prompt != prompt or new_negative_prompt != negative_prompt:
        new_prompt_embeds = model.encode_prompt(new_prompt, new_negative_prompt)
        kwargs_new = model.get_pre_kwargs(height=image_size[-2],
                                          width=image_size[-1],
                                          prompt_embeds=new_prompt_embeds)
    else:
        new_prompt_embeds = orig_prompt_embeds
        kwargs_new = kwargs_orig

    set_seed(100000)
    xt = torch.randn(1, *model.get_latent_shape(image_size), device=device, dtype=dtype)

    # Candidate mixing coefficients: 2 ** bits evenly spaced values in (0, 1].
    # They do not depend on the timestep, so they are built once.
    pursuit_coefs = torch.linspace(0, 1, 2 ** num_pursuit_coef_bits + 1)[1:]

    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        set_seed(idx)
        dont_optimize_t = not (t_range[0] >= t >= t_range[1])  # No intermittent support

        if robust_randn:
            # NOTE(review): get_robust_randn is neither defined nor imported in
            # the visible part of this module - TODO confirm where it lives.
            noise = get_robust_randn(num_noises if not dont_optimize_t else 1,
                                     xt.shape[1:], device, dtype)
        else:
            noise = torch.randn(num_noises if not dont_optimize_t else 1,
                                *xt.shape[1:], device=device, dtype=dtype)

        curr_embs = orig_prompt_embeds if idx < tedit else new_prompt_embeds
        curr_kwargs = kwargs_orig if idx < tedit else kwargs_new

        epst = model.get_epst(xt, t, curr_embs, guidance_scale, **curr_kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)

        curr_t_noise_indices = noise_indices[idx]
        best_noise = noise[curr_t_noise_indices[0]]

        if num_pursuit_noises > 1:
            curr_t_coeffs_indices = coeffs_indices[idx]
            if curr_t_coeffs_indices[0] == -1:
                # A -1 sentinel skips this timestep entirely: finish_step is
                # not called and xt stays as it is.
                # NOTE(review): presumably mirrors the encoder - confirm.
                continue
            for pursuit_idx in range(1, num_pursuit_noises):
                pursuit_coef = pursuit_coefs[curr_t_coeffs_indices[pursuit_idx]]
                best_noise = best_noise * torch.sqrt(pursuit_coef) + noise[
                    curr_t_noise_indices[pursuit_idx]] * torch.sqrt(1 - pursuit_coef)
                # Rescale the mixture to unit standard deviation.
                best_noise /= best_noise.std()

        best_noise = best_noise.unsqueeze(0)
        xt = model.finish_step(xt, x_0_hat, epst, t, best_noise)

    img = model.decode_image(xt)
    return img


def inf_generate(model: PipelineWrapper,
                 prompt: str = "",
                 negative_prompt: Optional[str] = None,
                 guidance_scale: float = 7.0,
                 record: int = 0,
                 save_root: str = "") -> torch.Tensor:
    """Sample an image with freshly drawn noise at every timestep.

    Args:
        model: Pipeline wrapper providing the denoise/decode steps.
        prompt: Text prompt.
        negative_prompt: Optional negative prompt.
        guidance_scale: Guidance scale passed to `model.get_epst`.
        record: When non-zero, every `record`-th step's decoded `x_0_hat`
            is saved as a PNG under `<save_root>/progress/`.
        save_root: Base directory for the recorded images.

    Returns:
        The decoded image.
    """
    device = model.device
    dtype = model.dtype

    model.set_timesteps(model.num_timesteps, device=device)
    prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs = model.get_pre_kwargs(height=model.get_image_size(),
                                  width=model.get_image_size(),
                                  prompt_embeds=prompt_embeds)
    xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()),
                     device=device, dtype=dtype)

    if record:
        # plt.imsave does not create missing directories.
        os.makedirs(os.path.join(save_root, "progress"), exist_ok=True)

    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        noise = torch.randn(1, *xt.shape[1:], device=device, dtype=dtype)

        epst = model.get_epst(xt, t, prompt_embeds, guidance_scale, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)
        xt = model.finish_step(xt, x_0_hat, epst, t, noise)

        if record and not idx % record:
            img = model.decode_image(x_0_hat)
            plt.imsave(os.path.join(save_root, f"progress/x_0_hat_{str(t.item()).zfill(4)}.png"),
                       clear_color(img[0].unsqueeze(0), normalize=False))

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        # Retry the decode with the latent moved to the CPU.
        img = model.decode_image(xt.to('cpu'))
    return img