# Spaces: Running on Zero
import warnings
from typing import Optional, Tuple, Union

import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
from transformers import Blip2Processor, Blip2ForConditionalGeneration
class PipelineWrapper(torch.nn.Module):
    """Common interface over a diffusion pipeline for step-by-step sampling.

    Subclasses are expected to load ``self.model`` (a diffusers-style pipeline whose
    ``scheduler`` exposes ``get_x_0_hat`` / ``finish_step`` / ``get_variance`` /
    ``set_timesteps``) and to implement the encode/decode/noise-prediction hooks.
    The scheduler-related methods here only delegate to ``self.model.scheduler``.

    Args:
        model_id: Model identifier, stored for subclasses to load from.
        timesteps: Number of inference timesteps (stored as ``num_timesteps``).
        device: Target device, stored for subclasses.
        float16: Whether subclasses should load weights in half precision.
        compile: Whether subclasses should ``torch.compile`` the denoiser.
            NOTE(review): on recent PyTorch this instance attribute shadows
            ``torch.nn.Module.compile``.
        token: Optional auth token used when downloading weights.
    """

    def __init__(self, model_id: str,
                 timesteps: int,
                 device: torch.device,
                 float16: bool = False,
                 compile: bool = True,
                 token: Optional[str] = None, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model_id = model_id
        self.num_timesteps = timesteps
        self.device = device
        self.float16 = float16
        self.token = token
        self.compile = compile
        self.model = None  # populated by subclasses

    def timesteps(self) -> torch.Tensor:
        """Return the scheduler's current timestep schedule."""
        return self.model.scheduler.timesteps

    def dtype(self) -> torch.dtype:
        """Return the UNet's dtype.

        Raises:
            AttributeError: if no model has been loaded yet.
        """
        if self.model is None:
            raise AttributeError("Model is not initialized.")
        return self.model.unet.dtype

    def get_x_0_hat(self, xt: torch.Tensor, epst: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        """Delegate the x_0 prediction to the wrapped scheduler."""
        return self.model.scheduler.get_x_0_hat(xt, epst, timestep)

    def finish_step(self, xt: torch.Tensor, pred_x0: torch.Tensor, epst: torch.Tensor,
                    timestep: torch.Tensor, variance_noise: torch.Tensor,
                    **kwargs) -> torch.Tensor:
        """Delegate the remainder of a sampling step to the wrapped scheduler."""
        return self.model.scheduler.finish_step(xt, pred_x0, epst, timestep, variance_noise, **kwargs)

    def get_variance(self, timestep: torch.Tensor) -> torch.Tensor:
        """Delegate the per-timestep variance computation to the wrapped scheduler."""
        return self.model.scheduler.get_variance(timestep)

    def set_timesteps(self, timesteps: int, device: torch.device) -> None:
        """Configure the wrapped scheduler for ``timesteps`` inference steps on ``device``."""
        self.model.scheduler.set_timesteps(timesteps, device=device)

    def encode_image(self, x: torch.Tensor) -> torch.Tensor:
        """Map images to the model's latent space (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement encode_image")

    def decode_image(self, x: torch.Tensor) -> torch.Tensor:
        """Map latents back to images (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement decode_image")

    def encode_prompt(self, prompt: str, negative_prompt=None) -> torch.Tensor:
        """Encode a text prompt into conditioning embeddings (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement encode_prompt")

    def get_epst(self, xt: torch.Tensor, t: torch.Tensor, prompt_embeds: torch.Tensor,
                 guidance_scale: Optional[float] = None, **kwargs) -> torch.Tensor:
        """Predict the model output for ``xt`` at timestep ``t`` (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement get_epst")

    def get_image_size(self) -> Union[int, Tuple[int, int]]:
        """Return the pixel resolution implied by the UNet's ``sample_size``.

        ``sample_size`` may be an int or a (height, width) pair; each side is
        multiplied by the VAE scale factor. An int input yields an int result.
        """
        sample_size = self.model.unet.config.sample_size
        scale = self.model.vae_scale_factor
        if isinstance(sample_size, (tuple, list)):
            # Scale each side; ``tuple * int`` would repeat the tuple instead.
            return tuple(side * scale for side in sample_size)
        return sample_size * scale

    def get_noise_shape(self, imsize: Union[int, Tuple[int, ...]], batch_size: int) -> Tuple[int, ...]:
        """Return ``(batch, unet_in_channels, H, W)`` for sampling noise.

        ``imsize`` is an int (square) or a shape whose last two entries are H, W.
        NOTE(review): presumably given in latent resolution, since it is paired
        with the UNet's input channels — confirm against callers.
        """
        if isinstance(imsize, int):
            imsize = (imsize, imsize)
        return (batch_size,
                self.model.unet.config.in_channels,
                imsize[-2],
                imsize[-1])

    def get_latent_shape(self, orig_image_shape: Union[int, Tuple[int, int]]) -> Tuple[int, ...]:
        """Return ``(channels, H // scale, W // scale)`` for an image of the given size."""
        if isinstance(orig_image_shape, int):
            orig_image_shape = (orig_image_shape, orig_image_shape)
        return (self.model.unet.config.in_channels,
                orig_image_shape[0] // self.model.vae_scale_factor,
                orig_image_shape[1] // self.model.vae_scale_factor)

    def get_pre_kwargs(self, **kwargs) -> dict:
        """Return extra keyword arguments for sampling; none by default."""
        return {}
class StableDiffWrapper(PipelineWrapper):
    """``PipelineWrapper`` backed by a diffusers ``StableDiffusionPipeline``.

    The pipeline's scheduler is replaced by a ``DDIMWrapper`` whose ``eta`` is
    derived from ``scheduler``:

    - a name containing ``'ddpm'`` (e.g. the default ``'ddpm'``): ``eta = 1.0``
    - ``'ddim-<eta>'`` such as ``'ddim-0.5'``: ``eta = <eta>``
    - plain ``'ddim'``: ``eta = 0.0``

    Any other name raises ``ValueError``. Remaining arguments are forwarded to
    ``PipelineWrapper.__init__``.
    """

    def __init__(self, scheduler='ddpm', *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.scheduler_type = scheduler
        # Parse/validate first so an unsupported name fails before any weights are loaded.
        eta = self._parse_eta(scheduler)
        self.model = self._load_pipeline()
        self.model.scheduler = DDIMWrapper(model_id=self.model_id, device=self.device,
                                           eta=eta,
                                           float16=self.float16, token=self.token)
        self.model.scheduler.set_timesteps(self.num_timesteps, device=self.device)
        if self.compile:
            try:
                self.model.unet = torch.compile(self.model.unet, mode="reduce-overhead", fullgraph=True)
            except Exception as e:  # best effort: keep the eager UNet on failure
                # torch.compile compiles lazily on first call, so this mainly catches setup errors.
                warnings.warn(f"Error compiling model: {e}")

    @staticmethod
    def _parse_eta(scheduler: str) -> float:
        """Map a scheduler name to the DDIM ``eta``; raise ``ValueError`` if unsupported."""
        if not (scheduler == 'ddpm' or 'ddim' in scheduler):
            raise ValueError(
                f"Unsupported scheduler {scheduler!r}; expected 'ddpm', 'ddim' or 'ddim-<eta>'.")
        if 'ddpm' in scheduler:
            return 1.0
        parts = scheduler.split('-')
        # A bare 'ddim' carries no eta value: default to deterministic DDIM (eta = 0).
        return float(parts[1]) if len(parts) > 1 else 0.0

    def _load_pipeline(self) -> StableDiffusionPipeline:
        """Load the pipeline onto ``self.device``.

        On ``OSError`` the download is retried once with ``force_download=True``
        (presumably to recover from a broken local cache).
        """
        load_kwargs = dict(torch_dtype=torch.float16 if self.float16 else torch.float32,
                           token=self.token)
        try:
            pipe = StableDiffusionPipeline.from_pretrained(self.model_id, **load_kwargs)
        except OSError:
            pipe = StableDiffusionPipeline.from_pretrained(self.model_id, force_download=True,
                                                           **load_kwargs)
        return pipe.to(self.device)

    def encode_image(self, x: torch.Tensor) -> torch.Tensor:
        """Encode images with the VAE (mode of the latent distribution) and apply the scaling factor."""
        return self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor

    def decode_image(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latents with the VAE on ``x``'s device and clamp the result to [-1, 1].

        If the VAE lives on another device it is moved temporarily and always
        restored afterwards, even if decoding raises.
        """
        vae = self.model.vae
        scaled = x / vae.config.scaling_factor
        orig_device = vae.device
        if x.device == orig_device:
            return vae.decode(scaled).sample.clamp(-1, 1)
        vae.to(x.device)
        try:
            return vae.decode(scaled).sample.clamp(-1, 1)
        finally:
            vae.to(orig_device)

    def encode_prompt(self, prompt: str, negative_prompt=None) -> torch.Tensor:
        """Encode ``prompt`` via the pipeline's text encoder.

        Classifier-free guidance is enabled when a negative prompt is given or the
        prompt is non-empty; in that case the result is
        ``cat([negative_embeds, prompt_embeds])``, otherwise just ``prompt_embeds``.
        """
        do_cfg = (negative_prompt is not None) or prompt != ""
        prompt_embeds, negative_prompt_embeds = self.model.encode_prompt(
            prompt, self.device, 1,
            do_cfg,
            negative_prompt,
        )
        if do_cfg:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds

    def get_epst(self, xt: torch.Tensor, t: torch.Tensor, prompt_embeds: torch.Tensor,
                 guidance_scale: Optional[float] = None, return_everything=False,
                 **kwargs) -> Tuple[None, torch.Tensor, Optional[torch.Tensor]]:
        """Run the UNet on ``xt`` at timestep ``t``.

        Returns:
            ``(None, uncond_pred, text_pred)`` when ``prompt_embeds`` has more than
            one row (treated as CFG, with ``xt`` duplicated), else ``(None, pred, None)``.
            ``guidance_scale`` and ``return_everything`` are accepted but unused;
            combining the two predictions is left to the caller.

        NOTE(review): any batch of >1 embeddings is treated as CFG — assumes a single prompt.
        """
        do_cfg = prompt_embeds.shape[0] > 1
        model_in = torch.cat([xt] * 2) if do_cfg else xt
        noise_pred = self.model.unet(model_in, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0]
        if do_cfg:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            return None, noise_pred_uncond, noise_pred_text
        return None, noise_pred, None
class SchedulerWrapper(object):
    """Base adapter exposing a diffusers-style scheduler through a small stepping API.

    Subclasses assign ``self.scheduler`` and implement ``get_x_0_hat``,
    ``finish_step`` and ``get_variance``.

    Args:
        model_id: Model identifier, stored for subclasses to load a scheduler from.
        device: Target device, stored for subclasses.
        float16: Half-precision flag, stored for subclasses.
        token: Optional auth token, stored for subclasses.
    """

    def __init__(self, model_id: str, device: torch.device,
                 float16: bool = False, token: Optional[str] = None, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.model_id = model_id
        self.device = device
        self.float16 = float16
        self.token = token
        self.scheduler = None  # populated by subclasses

    def timesteps(self) -> torch.Tensor:
        """Return the wrapped scheduler's current timestep schedule."""
        return self.scheduler.timesteps

    def set_timesteps(self, timesteps: int, device: torch.device) -> None:
        """Configure ``timesteps`` inference steps on ``device``.

        If the schedule starts at ``num_train_timesteps`` (one past the last
        training step), the whole schedule is shifted down by one.
        """
        self.scheduler.set_timesteps(timesteps, device=device)
        if self.scheduler.timesteps[0] == self.scheduler.config.num_train_timesteps:
            self.scheduler.timesteps -= 1

    def get_x_0_hat(self, xt: torch.Tensor, epst: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        """Predict the clean sample x_0 from ``xt`` and the model output (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement get_x_0_hat")

    def finish_step(self, xt: torch.Tensor, pred_x0: torch.Tensor, epst: torch.Tensor,
                    timestep: torch.Tensor, variance_noise: torch.Tensor,
                    **kwargs) -> torch.Tensor:
        """Compute the previous sample from the current one (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement finish_step")

    def get_variance(self, timestep: torch.Tensor) -> torch.Tensor:
        """Return the sampling variance at ``timestep`` (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement get_variance")
class DDIMWrapper(SchedulerWrapper):
    """``SchedulerWrapper`` around diffusers' ``DDIMScheduler`` with a split update step.

    The update follows formula (12) of https://arxiv.org/pdf/2010.02502.pdf, with
    ``eta`` scaling sigma_t from formula (16): ``eta = 0`` adds no noise, while
    ``eta > 0`` adds ``eta * sqrt(variance) * variance_noise``.
    Only ``'epsilon'`` and ``'v_prediction'`` model outputs are supported.
    """

    _SUPPORTED_PREDICTION_TYPES = ('epsilon', 'v_prediction')

    def __init__(self, eta, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.scheduler = DDIMScheduler.from_pretrained(
            self.model_id, subfolder="scheduler",
            torch_dtype=torch.float16 if self.float16 else torch.float32,
            token=self.token,
            device=self.device, timestep_spacing='linspace')
        self.eta = eta

    def _prediction_type(self) -> str:
        """Return the scheduler's prediction type, rejecting ones this wrapper cannot handle."""
        prediction_type = self.scheduler.config.prediction_type
        if prediction_type not in self._SUPPORTED_PREDICTION_TYPES:
            raise ValueError(
                f"Unsupported prediction_type {prediction_type!r}; "
                f"expected one of {self._SUPPORTED_PREDICTION_TYPES}.")
        return prediction_type

    def _prev_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
        """Previous timestep under a uniform stride of num_train_timesteps // num_inference_steps."""
        return timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps

    def get_x_0_hat(self, xt: torch.Tensor, epst: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        """Predict x_0 ("predicted x_0" of formula (12)) from ``xt`` and the model output.

        Raises:
            ValueError: for an unsupported ``prediction_type``.
        """
        prediction_type = self._prediction_type()
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        beta_prod_t = 1 - alpha_prod_t
        if prediction_type == 'epsilon':
            return (xt - beta_prod_t ** 0.5 * epst) / alpha_prod_t ** 0.5
        # 'v_prediction'
        return alpha_prod_t ** 0.5 * xt - beta_prod_t ** 0.5 * epst

    def finish_step(self, xt: torch.Tensor, pred_x0: torch.Tensor, epst: torch.Tensor,
                    timestep: torch.Tensor, variance_noise: torch.Tensor,
                    eta=None) -> torch.Tensor:
        """Compute the previous sample x_{t-1} from x_t and the predicted x_0.

        Args:
            xt: Current sample.
            pred_x0: Predicted clean sample (e.g. from ``get_x_0_hat``).
            epst: Model output at ``timestep``.
            timestep: Current timestep.
            variance_noise: Noise added when ``eta > 0`` (must not be None then).
            eta: Overrides ``self.eta`` when given.

        Raises:
            ValueError: for an unsupported ``prediction_type``.
        """
        if eta is None:
            eta = self.eta
        prediction_type = self._prediction_type()
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self._get_alpha_prod_t_prev(self._prev_timestep(timestep))
        beta_prod_t = 1 - alpha_prod_t
        # sigma_t(eta) of formula (16): sqrt((1 - a_{t-1}) / (1 - a_t)) * sqrt(1 - a_t / a_{t-1})
        variance = self.get_variance(timestep)
        std_dev_t = eta * variance ** 0.5
        if prediction_type == 'epsilon':
            model_output_direction = epst
        else:  # 'v_prediction': convert v to an epsilon estimate
            model_output_direction = alpha_prod_t ** 0.5 * epst + beta_prod_t ** 0.5 * xt
        # "Direction pointing to x_t" of formula (12).
        # NOTE(review): the base below can become negative for eta > 1.
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * model_output_direction
        # x_{t-1} without the random-noise term of formula (12).
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_x0 + pred_sample_direction
        if eta > 0:
            prev_sample = prev_sample + std_dev_t * variance_noise
        return prev_sample

    def get_variance(self, timestep: torch.Tensor) -> torch.Tensor:
        """Return the DDIM variance between ``timestep`` and its previous timestep."""
        return self.scheduler._get_variance(timestep, self._prev_timestep(timestep))

    def _get_alpha_prod_t_prev(self, prev_timestep: torch.Tensor) -> torch.Tensor:
        """Return alpha_cumprod at ``prev_timestep``, or the final value once it drops below 0."""
        return self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
            else self.scheduler.final_alpha_cumprod
def load_model(model_id: str, timesteps: int, | |
device: torch.device, blip: bool = False, | |
float16: bool = False, token: Optional[str] = None, | |
compile: bool = True, | |
blip_model="Salesforce/blip2-opt-2.7b-coco", scheduler: str = 'ddpm') -> PipelineWrapper: | |
pipeline = StableDiffWrapper(model_id=model_id, timesteps=timesteps, device=device, | |
scheduler=scheduler, | |
float16=float16, token=token, compile=compile) | |
pipeline = pipeline.to(device) | |
if blip: | |
pipeline.blip_processor = Blip2Processor.from_pretrained(blip_model) | |
try: | |
print(device if torch.cuda.get_device_properties(0).total_memory/(1024**3) > 18 else 'cpu') | |
pipeline.blip_model = Blip2ForConditionalGeneration.from_pretrained( | |
blip_model,).to(device if torch.cuda.get_device_properties(0).total_memory/(1024**3) > 18 else 'cpu') | |
except OSError: | |
pipeline.blip_model = Blip2ForConditionalGeneration.from_pretrained( | |
blip_model, force_download=True).to(device if torch.cuda.get_device_properties(0).total_memory/(1024**3) > 18 else 'cpu') | |
pipeline.blip_max_words = 32 | |
image_size = pipeline.get_image_size() | |
return pipeline, image_size | |