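"""Utilities for a Hugging Face ZeroGPU Space: image dataset loading plus
compression, decompression, and generation routines built on a diffusion
`PipelineWrapper`, where the per-step Gaussian noises are drawn from seeded,
reproducible codebooks."""
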
import os
import random
import time
from glob import glob
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import spaces  # noqa: F401 -- presumably required by the Hugging Face ZeroGPU runtime
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
from tqdm import tqdm

from latent_models import PipelineWrapper
from util.img_utils import clear_color

def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and torch (CPU and all CUDA devices) for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

class MinusOneToOne(torch.nn.Module):
    """Maps tensors from [0, 1] (ToTensor's range) to [-1, 1]."""
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor * 2 - 1


class ResizePIL(torch.nn.Module):
    """Resizes a PIL image to `image_size`; a no-op when `image_size` is None."""
    def __init__(self, image_size: Optional[Union[int, Tuple[int, int]]] = None):
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size

    def forward(self, pil_image: Image.Image) -> Image.Image:
        if self.image_size is not None:
            pil_image = pil_image.resize(self.image_size)
        return pil_image

def get_loader(datadir: str, batch_size: int = 1,
               crop_to: Optional[Union[int, Tuple[int, int]]] = None,
               include_path: bool = False) -> DataLoader:
    """Builds a DataLoader over `datadir` yielding image tensors scaled to [-1, 1]."""
    transform = transforms.Compose([
        ResizePIL(crop_to),
        transforms.ToTensor(),
        MinusOneToOne(),
    ])
    loader = DataLoader(FoldersDataset(datadir, transform, include_path=include_path),
                        batch_size=batch_size,
                        shuffle=True, num_workers=0, drop_last=False)
    return loader

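# Usage sketch (hypothetical directory; every item is an (image, stem, subdir) triple
# because FoldersDataset below always returns three values):
#   loader = get_loader("data/kodak", batch_size=1, crop_to=512)
#   for img, name, subdir in loader:
#       ...  # img: (B, 3, H, W), values in [-1, 1]
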
class FoldersDataset(VisionDataset):
    """Recursively loads .png/.JPEG/.jpg images from `root`; `root` may also be a single file."""
    def __init__(self, root: str, transforms: Optional[Callable] = None,
                 include_path: bool = False) -> None:
        super().__init__(root, transforms)
        self.include_path = include_path
        self.root = root
        if os.path.isdir(root):
            self.fpaths = glob(os.path.join(root, '**', '*.png'), recursive=True)
            self.fpaths += glob(os.path.join(root, '**', '*.JPEG'), recursive=True)
            self.fpaths += glob(os.path.join(root, '**', '*.jpg'), recursive=True)
            self.fpaths = sorted(self.fpaths)
            assert len(self.fpaths) > 0, "File list is empty. Check the root."
        elif os.path.exists(root):
            self.fpaths = [root]
        else:
            raise FileNotFoundError(f"File not found: {root}")

    def __len__(self) -> int:
        return len(self.fpaths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, str, str]:
        fpath = self.fpaths[index]
        img = Image.open(fpath).convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)

        path = ""
        if self.include_path:
            dirname = os.path.dirname(fpath)
            # remove root from dirname
            path = dirname[len(self.root) + 1:]
        return img, os.path.basename(fpath).split(os.extsep)[0], path

def compress(model: PipelineWrapper,
             img_to_compress: torch.Tensor,
             num_noises: int,
             loaded_indices: Optional[torch.Tensor],
             device,
             ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode `img_to_compress` as one codebook index per diffusion step.

    At every step a codebook of `num_noises` Gaussian noises is drawn from a seed
    derived from the step index, so a decoder can regenerate it exactly. Returns
    the reconstructed image and the selected indices.
    """
    # model.set_timesteps(model.num_timesteps, device=device)
    dtype = model.dtype
    prompt_embeds = model.encode_prompt("", None)

    set_seed(88888888)
    if img_to_compress is None:
        img_to_compress = torch.zeros(1, 3, model.get_image_size(), model.get_image_size(), device=device)
    enc_im = model.encode_image(img_to_compress.to(dtype))
    kwargs = model.get_pre_kwargs(height=img_to_compress.shape[-2], width=img_to_compress.shape[-1],
                                  prompt_embeds=prompt_embeds)

    set_seed(100000)
    xt = torch.randn(1, *enc_im.shape[1:], device=device, dtype=dtype)

    result_noise_indices = []
    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        set_seed(idx)  # per-step seed makes the codebook reproducible at decode time
        noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)

        _, epst, _ = model.get_epst(xt, t, prompt_embeds, 0.0, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)

        if loaded_indices is None:
            if t >= 1:
                # Pick the codebook noise best aligned with the direction toward the target latent.
                dot_prod = torch.matmul(noise.view(noise.shape[0], -1),
                                        (enc_im - x_0_hat).view(enc_im.shape[0], -1).transpose(0, 1))
                best_idx = torch.argmax(dot_prod)
                best_noise = noise[best_idx]
            else:
                best_noise = noise[0]
        else:
            if t >= 1:
                best_idx = loaded_indices[idx]
                best_noise = noise[best_idx]
            else:
                best_noise = noise[0]

        if t >= 1:
            result_noise_indices.append(best_idx)
        xt = model.finish_step(xt, x_0_hat, epst, t, best_noise.unsqueeze(0), eta=None)

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        img = model.decode_image(xt.to('cpu'))
    return img, torch.tensor(result_noise_indices).squeeze().cpu()

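# Rough rate estimate (an assumption about fixed-length coding, not a stored format):
# each selected index costs log2(num_noises) bits, so the payload is approximately
#   len(result_noise_indices) * math.log2(num_noises) bits.
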
def generate_ours(model: PipelineWrapper,
                  num_noises: int,
                  num_noises_to_optimize: int,
                  prompt: str = "",
                  negative_prompt: Optional[str] = None,
                  indices: Optional[torch.Tensor] = None,
                  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate an image by greedily picking, at each step, the codebook noise that best
    correlates with the difference between unconditional and conditional predictions."""
    device = model.device
    dtype = model.dtype
    # print(num_noises, num_noises_to_optimize, flush=True)
    # model.set_timesteps(model.num_timesteps, device=device)

    set_seed(88888888)
    if prompt is None:
        prompt = ""
    prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs = model.get_pre_kwargs(height=model.get_image_size(),
                                  width=model.get_image_size(),
                                  prompt_embeds=prompt_embeds)

    set_seed(100000)
    xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)

    result_noise_indices = []
    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        set_seed(idx)
        noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)  # Codebook

        _, epst_uncond, epst_cond = model.get_epst(xt, t, prompt_embeds, 1.0, return_everything=True, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst_uncond, t)

        if t >= 1:
            if indices is None:
                prev_classif_score = epst_uncond - epst_cond
                set_seed(int(time.time_ns() & 0xFFFFFFFF))  # random subset of the codebook
                noise_indices = torch.randint(0, num_noises, size=(num_noises_to_optimize,), device=device)
                # Correlate each candidate noise with the classifier score direction.
                loss = torch.matmul(noise[noise_indices].view(num_noises_to_optimize, -1),
                                    prev_classif_score.view(prev_classif_score.shape[0], -1).transpose(0, 1))
                best_idx = noise_indices[torch.argmax(loss)]
            else:
                best_idx = indices[idx]
            best_noise = noise[best_idx]
            result_noise_indices.append(best_idx)
        else:
            best_noise = torch.zeros_like(noise[0])

        xt = model.finish_step(xt, x_0_hat, epst_uncond, t, best_noise)

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        img = model.decode_image(xt.to('cpu'))
    return img, torch.stack(result_noise_indices).squeeze().cpu()

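# Example call (a sketch; constructing the PipelineWrapper is assumed to happen elsewhere):
#   img, chosen = generate_ours(model, num_noises=1024, num_noises_to_optimize=8,
#                               prompt="a photo of a red fox")
# Because every per-step codebook is reseeded from the step index, feeding `chosen`
# back in through the `indices` argument replays the exact same sample.
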
def decompress(model: PipelineWrapper,
               image_size: Tuple[int, int],
               indices: Dict[str, torch.Tensor],
               num_noises: int,
               prompt: str = "",
               negative_prompt: Optional[str] = None,
               tedit: int = 0,
               new_prompt: str = "",
               new_negative_prompt: Optional[str] = None,
               guidance_scale: float = 3.0,
               num_pursuit_noises: Optional[int] = 1,
               num_pursuit_coef_bits: Optional[int] = 3,
               t_range: Tuple[int, int] = (999, 0),
               robust_randn: bool = False
               ) -> torch.Tensor:
    """Reconstruct an image from stored codebook indices, optionally switching to
    `new_prompt` after the first `tedit` steps (for prompt-based editing)."""
    noise_indices = indices['noise_indices']
    coeffs_indices = indices['coeff_indices']
    num_pursuit_noises = num_pursuit_noises if num_pursuit_noises is not None else 1
    num_pursuit_coef_bits = num_pursuit_coef_bits if num_pursuit_coef_bits is not None else 1
    device = model.device
    dtype = model.dtype
    # model.set_timesteps(model.num_timesteps, device=device)

    set_seed(88888888)
    orig_prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs_orig = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
                                       prompt_embeds=orig_prompt_embeds)
    if new_prompt != prompt or new_negative_prompt != negative_prompt:
        new_prompt_embeds = model.encode_prompt(new_prompt, new_negative_prompt)
        kwargs_new = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
                                          prompt_embeds=new_prompt_embeds)
    else:
        new_prompt_embeds = orig_prompt_embeds
        kwargs_new = kwargs_orig

    set_seed(100000)
    xt = torch.randn(1, *model.get_latent_shape(image_size), device=device, dtype=dtype)

    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        set_seed(idx)
        dont_optimize_t = not (t_range[0] >= t >= t_range[1])
        # No intermittent support
        if robust_randn:
            # `get_robust_randn` is assumed to be defined elsewhere in the project.
            noise = get_robust_randn(num_noises if not dont_optimize_t else 1, xt.shape[1:], device, dtype)
        else:
            noise = torch.randn(num_noises if not dont_optimize_t else 1, *xt.shape[1:], device=device, dtype=dtype)

        curr_embs = orig_prompt_embeds if idx < tedit else new_prompt_embeds
        curr_kwargs = kwargs_orig if idx < tedit else kwargs_new
        epst = model.get_epst(xt, t, curr_embs, guidance_scale, **curr_kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)

        curr_t_noise_indices = noise_indices[idx]
        best_noise = noise[curr_t_noise_indices[0]]
        pursuit_coefs = torch.linspace(0, 1, 2 ** num_pursuit_coef_bits + 1)[1:]
        if num_pursuit_noises > 1:
            curr_t_coeffs_indices = coeffs_indices[idx]
            if curr_t_coeffs_indices[0] == -1:
                continue
            for pursuit_idx in range(1, num_pursuit_noises):
                # Blend in the next pursuit noise, keeping (approximately) unit variance.
                pursuit_coef = pursuit_coefs[curr_t_coeffs_indices[pursuit_idx]]
                best_noise = (best_noise * torch.sqrt(pursuit_coef)
                              + noise[curr_t_noise_indices[pursuit_idx]] * torch.sqrt(1 - pursuit_coef))
                best_noise /= best_noise.std()

        best_noise = best_noise.unsqueeze(0)
        xt = model.finish_step(xt, x_0_hat, epst, t, best_noise)

    img = model.decode_image(xt)
    return img

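# Expected `indices` payload, inferred from the loop above (shapes are an assumption):
#   indices = {
#       'noise_indices': ...,  # per step: num_pursuit_noises codebook indices
#       'coeff_indices': ...,  # per step: pursuit coefficient indices; a leading -1
#   }                          # makes that step skip its update (see the `continue`)
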
def inf_generate(model: PipelineWrapper,
                 prompt: str = "",
                 negative_prompt: Optional[str] = None,
                 guidance_scale: float = 7.0,
                 record: int = 0,
                 save_root: str = "") -> torch.Tensor:
    """Plain (codebook-free) sampling; if `record` > 0, saves the x_0 prediction
    every `record` steps under `save_root`/progress."""
    device = model.device
    dtype = model.dtype
    model.set_timesteps(model.num_timesteps, device=device)

    prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs = model.get_pre_kwargs(height=model.get_image_size(),
                                  width=model.get_image_size(),
                                  prompt_embeds=prompt_embeds)
    xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)

    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        noise = torch.randn(1, *xt.shape[1:], device=device, dtype=dtype)
        epst = model.get_epst(xt, t, prompt_embeds, guidance_scale, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)
        xt = model.finish_step(xt, x_0_hat, epst, t, noise)
        if record and not idx % record:
            img = model.decode_image(x_0_hat)
            plt.imsave(os.path.join(save_root, f"progress/x_0_hat_{str(t.item()).zfill(4)}.png"),
                       clear_color(img[0].unsqueeze(0), normalize=False))

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        img = model.decode_image(xt.to('cpu'))
    return img

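# Example (a sketch; paths are hypothetical, and recording expects save_root/progress to exist):
#   os.makedirs(os.path.join("out", "progress"), exist_ok=True)
#   final = inf_generate(model, prompt="an oil painting of a lighthouse",
#                        guidance_scale=7.0, record=50, save_root="out")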