import os
import spaces
import time
from glob import glob
from typing import Callable, Optional, Tuple, Union, Dict
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
from tqdm import tqdm
from util.img_utils import clear_color
from latent_models import PipelineWrapper
def set_seed(seed: int) -> None:
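    """Seed every RNG the pipeline draws from (PyTorch CPU/CUDA, NumPy,
    Python's `random`) so the per-step noise codebooks are reproducible."""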
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
class MinusOneToOne(torch.nn.Module):
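    """Rescale a tensor from [0, 1] to [-1, 1], the input range the diffusion model expects."""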
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor * 2 - 1
class ResizePIL(torch.nn.Module):
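    """Resize a PIL image to a fixed size; an int means (size, size), None leaves the image unchanged."""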
def __init__(self, image_size: Optional[Union[int, Tuple[int, int]]] = None):
super().__init__()
if isinstance(image_size, int):
image_size = (image_size, image_size)
self.image_size = image_size
def forward(self, pil_image: Image.Image) -> Image.Image:
if self.image_size is not None:
pil_image = pil_image.resize(self.image_size)
return pil_image
def get_loader(datadir: str, batch_size: int = 1,
crop_to: Optional[Union[int, Tuple[int, int]]] = None,
include_path: bool = False) -> DataLoader:
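    """Build a DataLoader over a folder of images, resized to `crop_to` and
    mapped to [-1, 1]. Items are (image, filename stem, relative subdir)
    tuples; the subdir is "" unless include_path=True."""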
transform = transforms.Compose([
ResizePIL(crop_to),
transforms.ToTensor(),
MinusOneToOne(),
])
loader = DataLoader(FoldersDataset(datadir, transform, include_path=include_path),
batch_size=batch_size,
shuffle=True, num_workers=0, drop_last=False)
return loader
class FoldersDataset(VisionDataset):
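    """Dataset of all .png/.jpg/.JPEG files found recursively under `root`;
    a path to a single image file is also accepted."""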
def __init__(self, root: str, transforms: Optional[Callable] = None,
include_path: bool = False) -> None:
super().__init__(root, transforms)
self.include_path = include_path
self.root = root
if os.path.isdir(root):
self.fpaths = glob(os.path.join(root, '**', '*.png'), recursive=True)
self.fpaths += glob(os.path.join(root, '**', '*.JPEG'), recursive=True)
self.fpaths += glob(os.path.join(root, '**', '*.jpg'), recursive=True)
self.fpaths = sorted(self.fpaths)
assert len(self.fpaths) > 0, "File list is empty. Check the root."
elif os.path.exists(root):
self.fpaths = [root]
else:
raise FileNotFoundError(f"File not found: {root}")
def __len__(self):
return len(self.fpaths)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, str, str]:
fpath = self.fpaths[index]
img = Image.open(fpath).convert('RGB')
if self.transforms is not None:
img = self.transforms(img)
path = ""
if self.include_path:
dirname = os.path.dirname(fpath)
# remove root from dirname
path = dirname[len(self.root) + 1:]
return img, os.path.basename(fpath).split(os.extsep)[0], path
@spaces.GPU
def compress(model: PipelineWrapper,
             img_to_compress: Optional[torch.Tensor],
             num_noises: int,
             loaded_indices: Optional[torch.Tensor],
             device: torch.device,
             ) -> Tuple[torch.Tensor, torch.Tensor]:
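    """Compress an image into a sequence of noise-codebook indices.

    At every denoising step the codebook (`num_noises` Gaussian noises drawn
    under a per-step fixed seed) is searched for the entry best aligned with
    the residual toward the encoded target, (enc_im - x_0_hat), and the
    winning index is recorded; the index sequence is the compressed
    representation. Passing `loaded_indices` replays a stored selection
    instead of searching, which reproduces the decompressed image.
    """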
# model.set_timesteps(model.num_timesteps, device=device)
dtype = model.dtype
prompt_embeds = model.encode_prompt("", None)
set_seed(88888888)
if img_to_compress is None:
img_to_compress = torch.zeros(1, 3, model.get_image_size(), model.get_image_size(), device=device)
enc_im = model.encode_image(img_to_compress.to(dtype))
kwargs = model.get_pre_kwargs(height=img_to_compress.shape[-2], width=img_to_compress.shape[-1],
prompt_embeds=prompt_embeds)
set_seed(100000)
xt = torch.randn(1, *enc_im.shape[1:], device=device, dtype=dtype)
result_noise_indices = []
pbar = tqdm(model.timesteps)
for idx, t in enumerate(pbar):
set_seed(idx)
noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)
_, epst, _ = model.get_epst(xt, t, prompt_embeds, 0.0, **kwargs)
x_0_hat = model.get_x_0_hat(xt, epst, t)
if loaded_indices is None:
if t >= 1:
dot_prod = torch.matmul(noise.view(noise.shape[0], -1),
(enc_im - x_0_hat).view(enc_im.shape[0], -1).transpose(0, 1))
best_idx = torch.argmax(dot_prod)
best_noise = noise[best_idx]
else:
best_noise = noise[0]
else:
if t >= 1:
best_idx = loaded_indices[idx]
best_noise = noise[best_idx]
else:
best_noise = noise[0]
if t >= 1:
result_noise_indices.append(best_idx)
xt = model.finish_step(xt, x_0_hat, epst, t, best_noise.unsqueeze(0), eta=None)
try:
img = model.decode_image(xt)
except torch.OutOfMemoryError:
img = model.decode_image(xt.to('cpu'))
    return img, torch.stack(result_noise_indices).squeeze().cpu()
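
# A minimal round-trip sketch (assumes `model` is an already-constructed
# PipelineWrapper and num_noises=64 is an arbitrary codebook size):
#
#   loader = get_loader("path/to/images", crop_to=model.get_image_size())
#   img, name, _ = next(iter(loader))
#   recon, idx = compress(model, img.to(model.device), num_noises=64,
#                         loaded_indices=None, device=model.device)
#   replay, _ = compress(model, img.to(model.device), num_noises=64,
#                        loaded_indices=idx, device=model.device)
#   # `replay` should match `recon`: every torch.randn call is re-seeded
#   # per step, so the codebooks are identical across runs.
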
@spaces.GPU
def generate_ours(model: PipelineWrapper,
num_noises: int,
num_noises_to_optimize: int,
prompt: str = "",
negative_prompt: Optional[str] = None,
                  indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
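    """Generate an image whose codebook indices double as a compressed
    representation. Each step scores `num_noises_to_optimize` randomly chosen
    entries of the `num_noises`-sized codebook against the guidance direction
    (epst_uncond - epst_cond) and applies the best-aligned one; the denoising
    step itself uses the unconditional prediction, so prompt adherence comes
    only from the noise selection. Passing `indices` replays a stored
    selection."""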
device = model.device
dtype = model.dtype
# print(num_noises, num_noises_to_optimize, flush=True)
# model.set_timesteps(model.num_timesteps, device=device)
set_seed(88888888)
if prompt is None:
prompt = ""
prompt_embeds = model.encode_prompt(prompt, negative_prompt)
kwargs = model.get_pre_kwargs(height=model.get_image_size(),
width=model.get_image_size(),
prompt_embeds=prompt_embeds)
set_seed(100000)
xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)
result_noise_indices = []
pbar = tqdm(model.timesteps)
for idx, t in enumerate(pbar):
set_seed(idx)
noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype) # Codebook
_, epst_uncond, epst_cond = model.get_epst(xt, t, prompt_embeds, 1.0, return_everything=True, **kwargs)
x_0_hat = model.get_x_0_hat(xt, epst_uncond, t)
if t >= 1:
if indices is None:
prev_classif_score = epst_uncond - epst_cond
set_seed(int(time.time_ns() & 0xFFFFFFFF))
noise_indices = torch.randint(0, num_noises, size=(num_noises_to_optimize,), device=device)
loss = torch.matmul(noise[noise_indices].view(num_noises_to_optimize, -1),
prev_classif_score.view(prev_classif_score.shape[0], -1).transpose(0, 1))
best_idx = noise_indices[torch.argmax(loss)]
else:
best_idx = indices[idx]
best_noise = noise[best_idx]
result_noise_indices.append(best_idx)
else:
best_noise = torch.zeros_like(noise[0])
xt = model.finish_step(xt, x_0_hat, epst_uncond, t, best_noise)
try:
img = model.decode_image(xt)
except torch.OutOfMemoryError:
img = model.decode_image(xt.to('cpu'))
return img, torch.stack(result_noise_indices).squeeze().cpu()
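
# Generation sketch (assumes `model` is an already-constructed PipelineWrapper;
# the codebook sizes are arbitrary):
#
#   img, idx = generate_ours(model, num_noises=64, num_noises_to_optimize=8,
#                            prompt="a photo of a cat")
#   img2, _ = generate_ours(model, num_noises=64, num_noises_to_optimize=8,
#                           prompt="a photo of a cat", indices=idx)
#   # img2 should reproduce img, since the per-step codebooks are re-seeded
#   # identically and only the (replayed) selection differs.
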
def decompress(model: PipelineWrapper,
image_size: Tuple[int, int],
indices: Dict[str, torch.Tensor],
num_noises: int,
prompt: str = "",
negative_prompt: Optional[str] = None,
tedit: int = 0,
new_prompt: str = "",
new_negative_prompt: Optional[str] = None,
guidance_scale: float = 3.0,
num_pursuit_noises: Optional[int] = 1,
num_pursuit_coef_bits: Optional[int] = 3,
t_range: Tuple[int, int] = (999, 0),
robust_randn: bool = False
) -> torch.Tensor:
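    """Reconstruct an image from stored codebook indices.

    Each timestep replays the noises selected during compression
    (`indices['noise_indices']`); with `num_pursuit_noises` > 1, several
    codebook entries are blended using coefficients quantized to
    `num_pursuit_coef_bits` bits (`indices['coeff_indices']`). `t_range`
    limits the steps for which a full codebook is drawn, and setting `tedit`
    with a `new_prompt` switches the conditioning from that step onward,
    enabling semantic editing of the decompressed image."""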
noise_indices = indices['noise_indices']
coeffs_indices = indices['coeff_indices']
num_pursuit_noises = num_pursuit_noises if num_pursuit_noises is not None else 1
num_pursuit_coef_bits = num_pursuit_coef_bits if num_pursuit_coef_bits is not None else 1
device = model.device
dtype = model.dtype
# model.set_timesteps(model.num_timesteps, device=device)
set_seed(88888888)
orig_prompt_embeds = model.encode_prompt(prompt, negative_prompt)
kwargs_orig = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
prompt_embeds=orig_prompt_embeds)
if new_prompt != prompt or new_negative_prompt != negative_prompt:
new_prompt_embeds = model.encode_prompt(new_prompt, new_negative_prompt)
kwargs_new = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
prompt_embeds=new_prompt_embeds)
else:
new_prompt_embeds = orig_prompt_embeds
kwargs_new = kwargs_orig
set_seed(100000)
xt = torch.randn(1, *model.get_latent_shape(image_size), device=device, dtype=dtype)
pbar = tqdm(model.timesteps)
for idx, t in enumerate(pbar):
set_seed(idx)
dont_optimize_t = not (t_range[0] >= t >= t_range[1])
        # Optimization runs only over the single contiguous t_range; intermittent ranges are not supported.
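        # NOTE: get_robust_randn is not defined or imported in this file; it
        # is assumed to be provided elsewhere in the codebase.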
if robust_randn:
noise = get_robust_randn(num_noises if not dont_optimize_t else 1, xt.shape[1:], device, dtype)
else:
noise = torch.randn(num_noises if not dont_optimize_t else 1, *xt.shape[1:], device=device, dtype=dtype)
curr_embs = orig_prompt_embeds if idx < tedit else new_prompt_embeds
curr_kwargs = kwargs_orig if idx < tedit else kwargs_new
epst = model.get_epst(xt, t, curr_embs, guidance_scale, **curr_kwargs)
x_0_hat = model.get_x_0_hat(xt, epst, t)
curr_t_noise_indices = noise_indices[idx]
best_noise = noise[curr_t_noise_indices[0]]
pursuit_coefs = torch.linspace(0, 1, 2 ** num_pursuit_coef_bits + 1)[1:]
if num_pursuit_noises > 1:
curr_t_coeffs_indices = coeffs_indices[idx]
if curr_t_coeffs_indices[0] == -1:
continue
for pursuit_idx in range(1, num_pursuit_noises):
pursuit_coef = pursuit_coefs[curr_t_coeffs_indices[pursuit_idx]]
best_noise = best_noise * torch.sqrt(pursuit_coef) + noise[
curr_t_noise_indices[pursuit_idx]] * torch.sqrt(1 - pursuit_coef)
best_noise /= best_noise.std()
best_noise = best_noise.unsqueeze(0)
xt = model.finish_step(xt, x_0_hat, epst, t, best_noise)
img = model.decode_image(xt)
return img
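
# Decompression sketch. The payload layout below is inferred from the indexing
# above (`stored_noise_idx` / `stored_coeff_idx` are hypothetical tensors of
# shape [num_steps, num_pursuit_noises]):
#
#   payload = {'noise_indices': stored_noise_idx,
#              'coeff_indices': stored_coeff_idx}  # unused if num_pursuit_noises == 1
#   img = decompress(model, (512, 512), payload, num_noises=64,
#                    prompt="a photo of a cat",
#                    tedit=10, new_prompt="a photo of a dog")
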
def inf_generate(model: PipelineWrapper,
prompt: str = "",
negative_prompt: Optional[str] = None,
guidance_scale: float = 7.0,
record: int = 0,
                 save_root: str = "") -> torch.Tensor:
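    """Plain classifier-free-guidance sampling loop (no codebook selection),
    used as a baseline. If `record` > 0, the intermediate x_0_hat prediction
    is decoded and saved under `save_root`/progress every `record` steps."""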
device = model.device
dtype = model.dtype
model.set_timesteps(model.num_timesteps, device=device)
prompt_embeds = model.encode_prompt(prompt, negative_prompt)
kwargs = model.get_pre_kwargs(height=model.get_image_size(),
width=model.get_image_size(),
prompt_embeds=prompt_embeds)
xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)
pbar = tqdm(model.timesteps)
for idx, t in enumerate(pbar):
noise = torch.randn(1, *xt.shape[1:], device=device, dtype=dtype)
epst = model.get_epst(xt, t, prompt_embeds, guidance_scale, **kwargs)
x_0_hat = model.get_x_0_hat(xt, epst, t)
xt = model.finish_step(xt, x_0_hat, epst, t, noise)
if record and not idx % record:
img = model.decode_image(x_0_hat)
plt.imsave(os.path.join(save_root, f"progress/x_0_hat_{str(t.item()).zfill(4)}.png"),
clear_color(img[0].unsqueeze(0), normalize=False))
try:
img = model.decode_image(xt)
except torch.OutOfMemoryError:
img = model.decode_image(xt.to('cpu'))
return img