| """ | |
| wild mixture of | |
| https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py | |
| https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py | |
| https://github.com/CompVis/taming-transformers | |
| -- merci | |
| """ | |
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from functools import partial
from einops import rearrange
from tqdm.auto import trange, tqdm
from ldmlib.modules.distributions.distributions import DiagonalGaussianDistribution
from ldmlib.models.autoencoder import VQModelInterface
from ldmlib.util import exists, default, instantiate_from_config
from ldmlib.modules.diffusionmodules.util import (make_beta_schedule, make_ddim_sampling_parameters,
                                                  make_ddim_timesteps, extract_into_tensor, noise_like)
from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims, linear_multistep_coeff
def disabled_train(self):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self
class DDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 timesteps=1000,
                 beta_schedule="linear",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1 - v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 ):
        super().__init__()
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config
        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
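    # The buffers registered above define the closed-form forward process
    # q(x_t | x_0) = N(sqrt(alphas_cumprod[t]) * x_0, (1 - alphas_cumprod[t]) * I),
    # i.e. x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * eps.
    # Every sampler below derives its parameters from betas / alphas_cumprod,
    # which is why the schedule lives in this shared base class.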
class FirstStage(DDPM):
    """main class"""
    def __init__(self,
                 first_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__()
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        self.instantiate_first_stage(first_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None
        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False
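    # The latent returned by the first stage is rescaled by scale_factor so the
    # diffusion model sees roughly unit-variance inputs (the Stable Diffusion v1
    # configs typically set scale_factor=0.18215; with scale_by_std the factor
    # is instead derived from the std of a batch of encodings at training start).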
    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()
        z = 1. / self.scale_factor * z
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        return self.first_stage_model.decode(z)
    def encode_first_stage(self, x):
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # e.g. (128, 128)
                stride = self.split_input_params["stride"]  # e.g. (64, 64)
                df = self.split_input_params["vqf"]
                self.split_input_params['original_image_size'] = x.shape[-2:]
                bs, nc, h, w = x.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing kernel")
                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")
                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
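                # Overlapping ks-sized crops are encoded independently below;
                # the per-pixel `weighting` (built by get_fold_unfold) tapers
                # each crop, `fold` stitches them back together, and dividing
                # by `normalization` (the fold of the weighting itself)
                # averages the overlap regions.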
                z = unfold(x)  # (bn, nc * prod(**ks), L)
                # Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L)
                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
                               for i in range(z.shape[-1])]
                o = torch.stack(output_list, axis=-1)
                o = o * weighting
                # Reverse reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization
                return decoded
            else:
                return self.first_stage_model.encode(x)
        else:
            return self.first_stage_model.encode(x)
class CondStage(DDPM):
    """main class"""
    def __init__(self,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__()
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None
        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True

    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model
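    # get_learned_conditioning maps raw conditioning (e.g. tokenized text) to
    # the embedding the UNet attends to: it prefers the model's encode()
    # method, reduces a DiagonalGaussianDistribution to its mode, and otherwise
    # falls back to calling the model directly (or a custom cond_stage_forward
    # hook named by the config).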
    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c
class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, cc):
        out = self.diffusion_model(x, t, context=cc)
        return out


class DiffusionWrapperOut(pl.LightningModule):
    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, h, emb, tp, hs, cc):
        return self.diffusion_model(h, emb, tp, hs, context=cc)
class UNet(DDPM):
    """main class"""
    def __init__(self,
                 unetConfigEncode,
                 unetConfigDecode,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 unet_bs=1,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        self.num_downs = 0
        self.cdevice = "cuda"
        self.unetConfigEncode = unetConfigEncode
        self.unetConfigDecode = unetConfigDecode
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None
        self.model1 = DiffusionWrapper(self.unetConfigEncode)
        self.model2 = DiffusionWrapperOut(self.unetConfigDecode)
        self.model1.eval()
        self.model2.eval()
        self.turbo = False
        self.unet_bs = unet_bs
        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
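    # The UNet is split in two: model1 (the encoder half) returns the
    # bottleneck features h, the timestep embedding emb, and the list of skip
    # activations hs; model2 (the decoder half) consumes them. Keeping the
    # halves as separate modules lets apply_model() move each half between CPU
    # and GPU (when turbo is off) to trade speed for VRAM.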
    def make_cond_schedule(self):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids
    def on_train_batch_start(self, batch, batch_idx):
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.cdevice)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")
    def apply_model(self, x_noisy, t, cond, return_ids=False):
        if not self.turbo:
            self.model1.to(self.cdevice)
        step = self.unet_bs
        h, emb, hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
        bs = cond.shape[0]
        # assert bs % 2 == 0
        lenhs = len(hs)
        for i in range(step, bs, step):
            h_temp, emb_temp, hs_temp = self.model1(x_noisy[i:i + step], t[i:i + step], cond[i:i + step])
            h = torch.cat((h, h_temp))
            emb = torch.cat((emb, emb_temp))
            for j in range(lenhs):
                hs[j] = torch.cat((hs[j], hs_temp[j]))
        if not self.turbo:
            self.model1.to("cpu")
            self.model2.to(self.cdevice)
        hs_temp = [hs[j][:step] for j in range(lenhs)]
        x_recon = self.model2(h[:step], emb[:step], x_noisy.dtype, hs_temp, cond[:step])
        for i in range(step, bs, step):
            hs_temp = [hs[j][i:i + step] for j in range(lenhs)]
            x_recon1 = self.model2(h[i:i + step], emb[i:i + step], x_noisy.dtype, hs_temp, cond[i:i + step])
            x_recon = torch.cat((x_recon, x_recon1))
        if not self.turbo:
            self.model2.to("cpu")
        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon
    def register_buffer1(self, name, attr):
        # plain-attribute variant of register_buffer: moves tensors to cdevice
        # without adding them to the module's state dict
        if isinstance(attr, torch.Tensor):
            if attr.device != torch.device(self.cdevice):
                attr = attr.to(torch.device(self.cdevice))
        setattr(self, name, attr)
| def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): | |
| self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, | |
| num_ddpm_timesteps=self.num_timesteps,verbose=verbose) | |
| assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' | |
| to_torch = lambda x: x.to(self.cdevice) | |
| self.register_buffer1('betas', to_torch(self.betas)) | |
| self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod)) | |
| # ddim sampling parameters | |
| ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(), | |
| ddim_timesteps=self.ddim_timesteps, | |
| eta=ddim_eta,verbose=verbose) | |
| self.register_buffer1('ddim_sigmas', ddim_sigmas) | |
| self.register_buffer1('ddim_alphas', ddim_alphas) | |
| self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev) | |
| self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) | |
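    # For the DDIM subsequence, make_ddim_sampling_parameters computes
    # sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
    # (Song et al. 2020, Eq. 16), so eta=0 gives the deterministic DDIM sampler
    # and eta=1 recovers DDPM-style ancestral variances on the subsequence.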
    def sample(self,
               S,
               conditioning,
               x0=None,
               shape=None,
               seed=1234,
               callback=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               sampler="plms",
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               ):
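        """Generator that dispatches to one of the samplers below and yields
        intermediate latents through the callbacks. S is the number of sampling
        steps; when x0 is None a fresh latent of `shape` is drawn with one seed
        per batch element (seed, seed+1, ...) so each image is reproducible
        independently of batch size."""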
        if self.turbo:
            self.model1.to(self.cdevice)
            self.model2.to(self.cdevice)
        if x0 is None:
            batch_size, b1, b2, b3 = shape
            img_shape = (1, b1, b2, b3)
            tens = []
            print("seeds used = ", [seed + s for s in range(batch_size)])
            for _ in range(batch_size):
                torch.manual_seed(seed)
                tens.append(torch.randn(img_shape, device=self.cdevice))
                seed += 1
            noise = torch.cat(tens)
            del tens
        x_latent = noise if x0 is None else x0
        batch_size = x_latent.shape[0]  # defined for both branches (`shape` is not needed when x0 is given)
        # sampling
        if sampler in ('ddim', 'dpm2', 'heun', 'dpm2_a', 'lms') and not hasattr(self, 'ddim_timesteps'):
            self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)

        if sampler == "plms":
            self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
            print(f'Data shape for PLMS sampling is {x_latent.shape}')
            samples = self.plms_sampling(conditioning, batch_size, x_latent,
                                         callback=callback,
                                         img_callback=img_callback,
                                         quantize_denoised=quantize_x0,
                                         mask=mask, x0=x0,
                                         ddim_use_original_steps=False,
                                         noise_dropout=noise_dropout,
                                         temperature=temperature,
                                         score_corrector=score_corrector,
                                         corrector_kwargs=corrector_kwargs,
                                         log_every_t=log_every_t,
                                         unconditional_guidance_scale=unconditional_guidance_scale,
                                         unconditional_conditioning=unconditional_conditioning)
        elif sampler == "ddim":
            samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
                                         unconditional_conditioning=unconditional_conditioning,
                                         mask=mask, init_latent=x_T, use_original_steps=False,
                                         callback=callback, img_callback=img_callback)
        elif sampler == "euler":
            self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
            samples = self.euler_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          img_callback=img_callback)
        elif sampler == "euler_a":
            self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
            samples = self.euler_ancestral_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    img_callback=img_callback)
        elif sampler == "dpm2":
            samples = self.dpm_2_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          img_callback=img_callback)
        elif sampler == "heun":
            samples = self.heun_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                         unconditional_guidance_scale=unconditional_guidance_scale,
                                         img_callback=img_callback)
        elif sampler == "dpm2_a":
            samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    img_callback=img_callback)
        elif sampler == "lms":
            samples = self.lms_sampling(self.alphas_cumprod, x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                        img_callback=img_callback)
        yield from samples
        if self.turbo:
            self.model1.to("cpu")
            self.model2.to("cpu")
    def plms_sampling(self, cond, b, img,
                      ddim_use_original_steps=False,
                      callback=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        device = self.betas.device
        timesteps = self.ddim_timesteps
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")
        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img
            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: yield from callback(i)
            if img_callback: yield from img_callback(pred_x0, i)
        if img_callback: yield from img_callback(img, len(iterator) - 1)
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
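        """Single PLMS step. get_model_output applies classifier-free guidance,
        eps = eps_uncond + scale * (eps_cond - eps_uncond), and the eps history
        in old_eps selects the pseudo linear multistep order (Liu et al. 2022,
        "Pseudo Numerical Methods for Diffusion Models on Manifolds")."""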
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            if score_corrector is not None:
                assert self.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
            return e_t

        alphas = self.ddim_alphas
        alphas_prev = self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
        sigmas = self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
        return x_prev, pred_x0, e_t
    def stochastic_encode(self, x0, t, seed, ddim_eta, ddim_steps, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
        if noise is None:
            b0, b1, b2, b3 = x0.shape
            img_shape = (1, b1, b2, b3)
            tens = []
            print("seeds used = ", [seed + s for s in range(b0)])
            for _ in range(b0):
                torch.manual_seed(seed)
                tens.append(torch.randn(img_shape, device=x0.device))
                seed += 1
            noise = torch.cat(tens)
            del tens
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
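    # Both stochastic_encode and add_noise apply the closed-form forward
    # process on the DDIM subsequence:
    #   x_t = sqrt(ddim_alphas[t]) * x0 + sqrt(1 - ddim_alphas[t]) * noise,
    # where t indexes the strided schedule rather than the full DDPM range.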
    def add_noise(self, x0, t):
        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
        noise = torch.randn(x0.shape, device=x0.device)
        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
    def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
                      mask=None, init_latent=None, use_original_steps=False,
                      callback=None, img_callback=None):
        timesteps = self.ddim_timesteps
        timesteps = timesteps[:t_start]
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")
        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        x0 = init_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            if mask is not None:
                # x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
                x0_noisy = x0
                x_dec = x0_noisy * mask + (1. - mask) * x_dec
            x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
                                       unconditional_conditioning=unconditional_conditioning)
            if callback: yield from callback(i)
            if img_callback: yield from img_callback(x_dec, i)
        if mask is not None:
            x_dec = x0 * mask + (1. - mask) * x_dec
        if img_callback: yield from img_callback(x_dec, len(iterator) - 1)
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        b, *_, device = *x.shape, x.device
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
        if score_corrector is not None:
            assert self.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
        alphas = self.ddim_alphas
        alphas_prev = self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
        sigmas = self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev
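    # The samplers below follow Karras et al. (2022) / k-diffusion conventions:
    # CompVisDenoiser reparameterizes the eps-model in sigma space, the batch
    # is doubled for classifier-free guidance, and to_d() converts the denoised
    # estimate into the probability-flow ODE derivative d = (x - denoised) / sigma.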
    def euler_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                       extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
                       img_callback=None):
        """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running Euler Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = (sigmas[i] * (gamma + 1)).half()
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            s_i = sigma_hat * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            d = to_d(x, sigma_hat, denoised)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
            if img_callback: yield from img_callback(x, i)
            dt = sigmas[i + 1] - sigma_hat
            # Euler method
            x = x + d * dt
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
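    # The ancestral variants split each step with get_ancestral_step(): the
    # state moves deterministically from sigma_i down to sigma_down, then fresh
    # noise scaled by sigma_up is injected so the marginal variance still
    # matches the target sigma_{i+1}.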
    def euler_ancestral_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                                 extra_args=None, callback=None, disable=None,
                                 img_callback=None):
        """Ancestral sampling with Euler method steps."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running Euler Ancestral Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            s_i = sigmas[i] * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            if img_callback: yield from img_callback(x, i)
            d = to_d(x, sigmas[i], denoised)
            # Euler method
            dt = sigma_down - sigmas[i]
            x = x + d * dt
            x = x + torch.randn_like(x) * sigma_up
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
    def heun_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                      extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
                      img_callback=None):
        """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(alphas_cumprod=ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running Heun Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = (sigmas[i] * (gamma + 1)).half()
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            s_i = sigma_hat * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            d = to_d(x, sigma_hat, denoised)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
            if img_callback: yield from img_callback(x, i)
            dt = sigmas[i + 1] - sigma_hat
            if sigmas[i + 1] == 0:
                # Euler method
                x = x + d * dt
            else:
                # Heun's method: correct the Euler step with the averaged slope
                x_2 = x + d * dt
                s_i = sigmas[i + 1] * s_in
                x_in = torch.cat([x_2] * 2)
                t_in = torch.cat([s_i] * 2)
                cond_in = torch.cat([unconditional_conditioning, cond])
                c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
                eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
                e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
                denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
                d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
                d_prime = (d + d_2) / 2
                x = x + d_prime * dt
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
    def dpm_2_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                       extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
                       img_callback=None):
        """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running DPM2 Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = sigmas[i] * (gamma + 1)
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            s_i = sigma_hat * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            if img_callback: yield from img_callback(x, i)
            d = to_d(x, sigma_hat, denoised)
            # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
            sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
            dt_1 = sigma_mid - sigma_hat
            dt_2 = sigmas[i + 1] - sigma_hat
            x_2 = x + d * dt_1
            s_i = sigma_mid * s_in
            x_in = torch.cat([x_2] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
    def dpm_2_ancestral_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                                 extra_args=None, callback=None, disable=None,
                                 img_callback=None):
        """Ancestral sampling with DPM-Solver inspired second-order steps."""
        extra_args = {} if extra_args is None else extra_args
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running DPM2 Ancestral Sampling with {len(sigmas) - 1} timesteps")
        s_in = x.new_ones([x.shape[0]]).half()
        for i in trange(len(sigmas) - 1, disable=disable):
            s_i = sigmas[i] * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            if img_callback: yield from img_callback(x, i)
            d = to_d(x, sigmas[i], denoised)
            # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
            sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
            dt_1 = sigma_mid - sigmas[i]
            dt_2 = sigma_down - sigmas[i]
            x_2 = x + d * dt_1
            s_i = sigma_mid * s_in
            x_in = torch.cat([x_2] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            d_2 = to_d(x_2, sigma_mid, denoised_2)
            x = x + d_2 * dt_2
            x = x + torch.randn_like(x) * sigma_up
        if img_callback: yield from img_callback(x, len(sigmas) - 1)
    def lms_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                     extra_args=None, callback=None, disable=None, order=4,
                     img_callback=None):
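        """Linear multistep sampling: the last `order` ODE derivatives are kept
        in `ds`, and linear_multistep_coeff integrates the interpolating
        polynomial over [sigma_i, sigma_{i+1}] to weight them (an
        Adams-Bashforth-style update in sigma space)."""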
        extra_args = {} if extra_args is None else extra_args
        s_in = x.new_ones([x.shape[0]])
        cvd = CompVisDenoiser(ac)
        sigmas = cvd.get_sigmas(S)
        x = x * sigmas[0]
        print(f"Running LMS Sampling with {len(sigmas) - 1} timesteps")
        ds = []
        for i in trange(len(sigmas) - 1, disable=disable):
            s_i = sigmas[i] * s_in
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([s_i] * 2)
            cond_in = torch.cat([unconditional_conditioning, cond])
            c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
            eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
            e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
            denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
            if img_callback: yield from img_callback(x, i)
            d = to_d(x, sigmas[i], denoised)
            ds.append(d)
            if len(ds) > order:
                ds.pop(0)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            cur_order = min(i + 1, order)
            coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
            x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
        if img_callback: yield from img_callback(x, len(sigmas) - 1)