from typing import Optional, Tuple, Union

import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Minimum total GPU memory (in GiB) required to place the BLIP-2 captioning
# model on the GPU; below this threshold it is kept on the CPU.
_BLIP_MIN_GPU_MEMORY_GB = 18


class PipelineWrapper(torch.nn.Module):
    """Common interface around a diffusion pipeline stored in ``self.model``.

    The base class only stores configuration and forwards scheduler-related
    calls to ``self.model.scheduler``. ``encode_image``, ``decode_image``,
    ``encode_prompt`` and ``get_epst`` are placeholders here (they return
    ``None``) and are implemented by subclasses such as ``StableDiffWrapper``.
    """

    def __init__(self, model_id: str,
                 timesteps: int,
                 device: torch.device,
                 float16: bool = False,
                 compile: bool = True,
                 token: Optional[str] = None, *args, **kwargs) -> None:
        """Store the configuration; the pipeline itself is loaded by subclasses.

        Args:
            model_id: Identifier passed to ``from_pretrained`` by subclasses.
            timesteps: Number of inference timesteps (kept as ``num_timesteps``).
            device: Device the pipeline should live on.
            float16: Load weights in ``torch.float16`` instead of ``float32``.
            compile: Whether subclasses should ``torch.compile`` the UNet.
            token: Optional access token forwarded to ``from_pretrained``.
        """
        super().__init__(*args, **kwargs)
        self.model_id = model_id
        self.num_timesteps = timesteps
        self.device = device
        self.float16 = float16
        self.token = token
        # NOTE: this instance attribute intentionally keeps the original name,
        # even though it shadows any ``compile`` method of ``torch.nn.Module``.
        self.compile = compile
        self.model = None

    # def get_sigma(self, timestep: int) -> float:
    #     sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1)
    #     return sqrt_recipm1_alphas_cumprod[timestep]

    @property
    def timesteps(self) -> torch.Tensor:
        """Timesteps currently configured on the underlying scheduler."""
        return self.model.scheduler.timesteps

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the UNet.

        Raises:
            AttributeError: If no pipeline has been loaded yet.
        """
        if self.model is None:
            raise AttributeError("Model is not initialized.")
        return self.model.unet.dtype

    def get_x_0_hat(self, xt: torch.Tensor, epst: torch.Tensor,
                    timestep: torch.Tensor) -> torch.Tensor:
        """Forward to the scheduler's ``get_x_0_hat``."""
        return self.model.scheduler.get_x_0_hat(xt, epst, timestep)

    def finish_step(self, xt: torch.Tensor, pred_x0: torch.Tensor, epst: torch.Tensor,
                    timestep: torch.Tensor, variance_noise: torch.Tensor,
                    **kwargs) -> torch.Tensor:
        """Forward to the scheduler's ``finish_step`` (extra kwargs included)."""
        return self.model.scheduler.finish_step(xt, pred_x0, epst, timestep,
                                                variance_noise, **kwargs)

    def get_variance(self, timestep: torch.Tensor) -> torch.Tensor:
        """Forward to the scheduler's ``get_variance``."""
        return self.model.scheduler.get_variance(timestep)

    def set_timesteps(self, timesteps: int, device: torch.device) -> None:
        """Forward to the scheduler's ``set_timesteps``."""
        self.model.scheduler.set_timesteps(timesteps, device=device)

    def encode_image(self, x: torch.Tensor) -> torch.Tensor:
        """Placeholder; subclasses encode an image into latent space."""
        pass

    def decode_image(self, x: torch.Tensor) -> torch.Tensor:
        """Placeholder; subclasses decode latents back into an image."""
        pass

    def encode_prompt(self, prompt: str, negative_prompt=None) -> torch.Tensor:
        """Placeholder; subclasses turn a text prompt into embeddings."""
        pass

    def get_epst(self, xt: torch.Tensor, t: torch.Tensor, prompt_embeds: torch.Tensor,
                 guidance_scale: Optional[float] = None,
                 **kwargs) -> Tuple[None, torch.Tensor, Optional[torch.Tensor]]:
        """Placeholder; subclasses run the denoiser on ``xt`` at step ``t``."""
        pass

    def get_image_size(self) -> int:
        """Return ``unet.config.sample_size * vae_scale_factor``.

        This is a single number when ``sample_size`` is an integer
        (assumed here — TODO confirm for non-square UNet configs).
        """
        return self.model.unet.config.sample_size * self.model.vae_scale_factor

    def get_noise_shape(self, imsize: Union[int, Tuple[int, ...]],
                        batch_size: int) -> Tuple[int, ...]:
        """Return ``(batch_size, in_channels, imsize[-2], imsize[-1])``.

        An integer ``imsize`` is treated as a square ``(imsize, imsize)``.
        """
        if isinstance(imsize, int):
            imsize = (imsize, imsize)
        return (batch_size, self.model.unet.config.in_channels,
                imsize[-2], imsize[-1])

    def get_latent_shape(self,
                         orig_image_shape: Union[int, Tuple[int, int]]) -> Tuple[int, ...]:
        """Return ``(in_channels, H // vae_scale_factor, W // vae_scale_factor)``.

        An integer ``orig_image_shape`` is treated as a square image.
        """
        if isinstance(orig_image_shape, int):
            orig_image_shape = (orig_image_shape, orig_image_shape)
        scale = self.model.vae_scale_factor
        return (self.model.unet.config.in_channels,
                orig_image_shape[0] // scale,
                orig_image_shape[1] // scale)

    def get_pre_kwargs(self, **kwargs) -> dict:
        """Return extra keyword arguments for the pipeline (none by default)."""
        return {}


class StableDiffWrapper(PipelineWrapper):
    """``PipelineWrapper`` backed by a ``StableDiffusionPipeline``."""

    def __init__(self, scheduler='ddpm', *args, **kwargs) -> None:
        """Load the pipeline, install the scheduler wrapper and compile the UNet.

        Args:
            scheduler: ``'ddpm'`` (eta = 1.0) or a string containing ``'ddim'``
                of the form ``'ddim-<eta>'``; the text after the first ``'-'``
                is parsed as the float eta. Any other value leaves the
                pipeline's own scheduler in place.
            *args, **kwargs: Forwarded to ``PipelineWrapper.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.scheduler_type = scheduler
        weights_dtype = torch.float16 if self.float16 else torch.float32
        try:
            self.model = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=weights_dtype,
                token=self.token).to(self.device)
        except OSError:
            # Retry once with a forced re-download when loading fails.
            self.model = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=weights_dtype,
                token=self.token, force_download=True
            ).to(self.device)

        if scheduler == 'ddpm' or 'ddim' in scheduler:
            eta = 1.0 if 'ddpm' in scheduler else float(scheduler.split('-')[1])
            self.model.scheduler = DDIMWrapper(model_id=self.model_id, device=self.device,
                                               eta=eta,
                                               float16=self.float16, token=self.token)

        self.model.scheduler.set_timesteps(self.num_timesteps, device=self.device)
        if self.compile:
            try:
                self.model.unet = torch.compile(self.model.unet, mode="reduce-overhead",
                                                fullgraph=True)
            except Exception as e:
                # Best effort: fall back to the uncompiled UNet.
                print(f"Error compiling model: {e}")

    def encode_image(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` with the VAE and scale the latent distribution's mode."""
        return (self.model.vae.encode(x).latent_dist.mode() *
                self.model.vae.config.scaling_factor)  # .float()

    def _vae_decode(self, x: torch.Tensor) -> torch.Tensor:
        """Undo the latent scaling, decode with the VAE and clamp to [-1, 1]."""
        return self.model.vae.decode(
            x / self.model.vae.config.scaling_factor).sample.clamp(-1, 1)

    def decode_image(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latents ``x`` into an image clamped to ``[-1, 1]``.

        When ``x`` is on a different device than ``self.device`` the VAE is
        temporarily moved to ``x.device`` and moved back afterwards, even if
        decoding raises.
        """
        if x.device != self.device:
            orig_device = self.model.vae.device
            self.model.vae.to(x.device)
            try:
                return self._vae_decode(x)
            finally:
                # Always restore the VAE so a failed decode does not leave
                # the pipeline split across devices.
                self.model.vae.to(orig_device)
        return self._vae_decode(x)

    def encode_prompt(self, prompt: str, negative_prompt=None) -> torch.Tensor:
        """Encode ``prompt`` (and optionally ``negative_prompt``).

        Classifier-free guidance is enabled when a negative prompt is given
        or the prompt is not the empty string; in that case the returned
        tensor is ``cat([negative_prompt_embeds, prompt_embeds])``.
        """
        do_cfg = (negative_prompt is not None) or prompt != ""

        prompt_embeds, negative_prompt_embeds = self.model.encode_prompt(
            prompt, self.device, 1, do_cfg, negative_prompt,
        )

        if do_cfg:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def get_epst(self, xt: torch.Tensor, t: torch.Tensor, prompt_embeds: torch.Tensor,
                 guidance_scale: Optional[float] = None, return_everything=False,
                 **kwargs) -> Tuple[None, torch.Tensor, Optional[torch.Tensor]]:
        """Run the UNet on ``xt`` at timestep ``t``.

        Returns:
            ``(None, noise_pred_uncond, noise_pred_text)`` when
            ``prompt_embeds`` holds more than one entry (classifier-free
            guidance), otherwise ``(None, noise_pred, None)``.
            ``guidance_scale`` and ``return_everything`` are accepted but not
            used here; no guidance mixing is applied in this method.
        """
        do_cfg = prompt_embeds.shape[0] > 1
        # Duplicate the latents so both embeddings are evaluated in one pass.
        xt = torch.cat([xt] * 2) if do_cfg else xt

        # predict the noise residual
        noise_pred = self.model.unet(xt, t, encoder_hidden_states=prompt_embeds,
                                     return_dict=False)[0]

        if do_cfg:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            return None, noise_pred_uncond, noise_pred_text
        return None, noise_pred, None


class SchedulerWrapper(object):
    """Base class wrapping a scheduler stored in ``self.scheduler``.

    ``get_x_0_hat``, ``finish_step`` and ``get_variance`` are placeholders
    (returning ``None``) that concrete wrappers implement.
    """

    def __init__(self, model_id: str, device: torch.device,
                 float16: bool = False, token: Optional[str] = None,
                 *args, **kwargs) -> None:
        """Store configuration; subclasses create ``self.scheduler``."""
        super().__init__(*args, **kwargs)
        self.model_id = model_id
        self.device = device
        self.float16 = float16
        self.token = token
        self.scheduler = None

    @property
    def timesteps(self) -> torch.Tensor:
        """Timesteps currently configured on the wrapped scheduler."""
        return self.scheduler.timesteps

    def set_timesteps(self, timesteps: int, device: torch.device) -> None:
        """Configure the wrapped scheduler's timesteps.

        If the first timestep equals 1000, all timesteps are shifted down by
        one (presumably so they remain valid indices into a 1000-entry
        schedule — TODO confirm).
        """
        self.scheduler.set_timesteps(timesteps, device=device)
        if self.scheduler.timesteps[0] == 1000:
            self.scheduler.timesteps -= 1

    def get_x_0_hat(self, xt: torch.Tensor, epst: torch.Tensor,
                    timestep: torch.Tensor) -> torch.Tensor:
        """Placeholder; subclasses predict the clean sample."""
        pass

    def finish_step(self, xt: torch.Tensor, pred_x0: torch.Tensor, epst: torch.Tensor,
                    timestep: torch.Tensor, variance_noise: torch.Tensor,
                    **kwargs) -> torch.Tensor:
        """Placeholder; subclasses compute the previous sample."""
        pass

    def get_variance(self, timestep: torch.Tensor) -> torch.Tensor:
        """Placeholder; subclasses return the step variance."""
        pass


class DDIMWrapper(SchedulerWrapper):
    """``SchedulerWrapper`` around ``diffusers.DDIMScheduler`` with a fixed eta."""

    def __init__(self, eta: float, *args, **kwargs) -> None:
        """Load the scheduler from the model's ``scheduler`` subfolder.

        Args:
            eta: Default noise scale used by ``finish_step``.
            *args, **kwargs: Forwarded to ``SchedulerWrapper.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.scheduler = DDIMScheduler.from_pretrained(
            self.model_id, subfolder="scheduler",
            torch_dtype=torch.float16 if self.float16 else torch.float32,
            token=self.token,
            device=self.device, timestep_spacing='linspace')
        self.eta = eta

    def _unsupported_prediction_type(self) -> ValueError:
        """Build the error raised for prediction types this wrapper cannot handle."""
        return ValueError(
            f"Unsupported prediction_type "
            f"{self.scheduler.config.prediction_type!r}; "
            f"expected 'epsilon' or 'v_prediction'.")

    def _get_prev_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
        """Return ``timestep - num_train_timesteps // num_inference_steps``."""
        return timestep - self.scheduler.config.num_train_timesteps // \
            self.scheduler.num_inference_steps

    def get_x_0_hat(self, xt: torch.Tensor, epst: torch.Tensor,
                    timestep: torch.Tensor) -> torch.Tensor:
        """Compute the "predicted x_0" from the model output ``epst``.

        See formula (12) of https://arxiv.org/pdf/2010.02502.pdf.

        Raises:
            ValueError: If the scheduler's ``prediction_type`` is neither
                ``'epsilon'`` nor ``'v_prediction'``.
        """
        # compute alphas, betas
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        beta_prod_t = 1 - alpha_prod_t

        prediction_type = self.scheduler.config.prediction_type
        if prediction_type == 'epsilon':
            return (xt - beta_prod_t ** (0.5) * epst) / alpha_prod_t ** (0.5)
        if prediction_type == 'v_prediction':
            return (alpha_prod_t ** 0.5) * xt - (beta_prod_t ** 0.5) * epst
        raise self._unsupported_prediction_type()

    def finish_step(self, xt: torch.Tensor, pred_x0: torch.Tensor, epst: torch.Tensor,
                    timestep: torch.Tensor, variance_noise: torch.Tensor,
                    eta: Optional[float] = None) -> torch.Tensor:
        """Compute the previous sample given the predicted ``x_0``.

        Args:
            xt: Current sample.
            pred_x0: Predicted clean sample (see ``get_x_0_hat``).
            epst: Model output at ``timestep``.
            timestep: Current timestep.
            variance_noise: Noise multiplied by the step's standard deviation
                and added when ``eta > 0``.
            eta: Noise scale; defaults to ``self.eta`` when ``None``.

        Raises:
            ValueError: If the scheduler's ``prediction_type`` is neither
                ``'epsilon'`` nor ``'v_prediction'``.
        """
        if eta is None:
            eta = self.eta

        prev_timestep = self._get_prev_timestep(timestep)

        # 2. compute alphas, betas
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self._get_alpha_prod_t_prev(prev_timestep)
        beta_prod_t = 1 - alpha_prod_t

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self.get_variance(timestep)
        std_dev_t = eta * variance ** (0.5)

        # Take care of asymmetric reverse process (asyrp)
        prediction_type = self.scheduler.config.prediction_type
        if prediction_type == 'epsilon':
            model_output_direction = epst
        elif prediction_type == 'v_prediction':
            model_output_direction = (alpha_prod_t**0.5) * epst + (beta_prod_t**0.5) * xt
        else:
            raise self._unsupported_prediction_type()

        # 6. compute "direction pointing to x_t" of formula (12)
        # from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = \
            (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction

        # 7. compute x_t without "random noise" of formula (12)
        # from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_x0 + pred_sample_direction

        # 8. Add noise if eta > 0
        if eta > 0:
            sigma_z = std_dev_t * variance_noise
            prev_sample = prev_sample + sigma_z

        return prev_sample

    def get_variance(self, timestep: torch.Tensor) -> torch.Tensor:
        """Return the scheduler's variance between ``timestep`` and its predecessor."""
        prev_timestep = self._get_prev_timestep(timestep)
        return self.scheduler._get_variance(timestep, prev_timestep)

    def _get_alpha_prod_t_prev(self, prev_timestep: torch.Tensor) -> torch.Tensor:
        """Return ``alphas_cumprod[prev_timestep]``, or ``final_alpha_cumprod`` if negative."""
        return self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
            else self.scheduler.final_alpha_cumprod


def _select_blip_device(device: torch.device) -> Union[torch.device, str]:
    """Pick where the BLIP-2 model should live.

    Returns ``device`` when it is a CUDA device whose total memory exceeds
    ``_BLIP_MIN_GPU_MEMORY_GB`` GiB, and ``'cpu'`` otherwise — including when
    CUDA is unavailable, in which case the GPU is never queried.
    """
    if not torch.cuda.is_available() or torch.device(device).type != 'cuda':
        return 'cpu'
    total_gb = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
    return device if total_gb > _BLIP_MIN_GPU_MEMORY_GB else 'cpu'


def load_model(model_id: str, timesteps: int,
               device: torch.device, blip: bool = False,
               float16: bool = False, token: Optional[str] = None,
               compile: bool = True,
               blip_model="Salesforce/blip2-opt-2.7b-coco",
               scheduler: str = 'ddpm') -> Tuple[PipelineWrapper, int]:
    """Build a ``StableDiffWrapper`` and optionally attach a BLIP-2 captioner.

    Args:
        model_id: Stable Diffusion model identifier.
        timesteps: Number of inference timesteps.
        device: Device for the diffusion pipeline.
        blip: When true, load ``blip_model`` and its processor and attach them
            as ``blip_model`` / ``blip_processor`` (plus ``blip_max_words``).
        float16: Load the diffusion weights in half precision.
        token: Optional access token forwarded to ``from_pretrained``.
        compile: Whether to ``torch.compile`` the UNet.
        blip_model: BLIP-2 checkpoint identifier.
        scheduler: Scheduler spec understood by ``StableDiffWrapper``.

    Returns:
        ``(pipeline, image_size)`` where ``image_size`` comes from
        ``pipeline.get_image_size()``.
    """
    pipeline = StableDiffWrapper(model_id=model_id, timesteps=timesteps, device=device,
                                 scheduler=scheduler,
                                 float16=float16, token=token, compile=compile)
    pipeline = pipeline.to(device)

    if blip:
        pipeline.blip_processor = Blip2Processor.from_pretrained(blip_model)
        # Decide the placement once; this is safe on hosts without CUDA.
        blip_device = _select_blip_device(device)
        print(blip_device)
        try:
            pipeline.blip_model = Blip2ForConditionalGeneration.from_pretrained(
                blip_model).to(blip_device)
        except OSError:
            # Retry once with a forced re-download when loading fails.
            pipeline.blip_model = Blip2ForConditionalGeneration.from_pretrained(
                blip_model, force_download=True).to(blip_device)
        pipeline.blip_max_words = 32

    image_size = pipeline.get_image_size()
    return pipeline, image_size