"""Reference-attention UNet wrappers and the ``Zero123PlusPipeline``.

The module defines:

* small tensor / image helpers,
* ``ReferenceOnlyAttnProc`` and ``RefOnlyNoisedUNet`` which share self-attention
  inputs between a conditioning pass and the main denoising pass,
* ``DepthControlUNet`` which adds ControlNet residuals,
* ``Zero123PlusPipeline`` with generation, SDEdit and refinement entry points.
"""

from collections import OrderedDict
from typing import Any, Dict, Optional

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers

import numpy
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.distributed
import transformers
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    EulerAncestralDiscreteScheduler,
    UNet2DConditionModel,
    ImagePipelineOutput,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor,
    XFormersAttnProcessor,
    AttnProcessor2_0,
)
from diffusers.utils.import_utils import is_xformers_available
import spaces

# `prepare_latents` draws its Gaussian noise through this module-level alias.
randn_tensor = torch.randn


def extract_into_tensor(a, t, x_shape):
    """Gather ``a`` at indices ``t`` along the last dim, shaped for broadcasting.

    Args:
        a: Tensor to gather from.
        t: Index tensor; its first dimension is used as the batch size.
        x_shape: Shape of the tensor the result will be broadcast against.

    Returns:
        The gathered values reshaped to ``(t.shape[0], 1, ..., 1)`` with
        ``len(x_shape)`` dimensions in total.
    """
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def to_rgb_image(maybe_rgba: Image.Image):
    """Return an RGB version of ``maybe_rgba``.

    RGB images are returned unchanged (same object). RGBA images are pasted
    onto a white background using their alpha channel as the mask.

    Raises:
        ValueError: If the image mode is neither ``"RGB"`` nor ``"RGBA"``.
    """
    if maybe_rgba.mode == "RGB":
        return maybe_rgba
    if maybe_rgba.mode == "RGBA":
        rgba = maybe_rgba
        # Solid white canvas of the same (width, height) as the input.
        img = Image.new("RGB", rgba.size, (255, 255, 255))
        img.paste(rgba, mask=rgba.getchannel("A"))
        return img
    raise ValueError("Unsupported image type.", maybe_rgba.mode)


class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention processor that records or injects reference attention inputs.

    The processor wraps ``chained_proc`` and, when ``enabled``, manipulates
    ``encoder_hidden_states`` according to ``mode`` before delegating:

    * ``"w"``: store ``encoder_hidden_states`` in ``ref_dict[self.name]``.
    * ``"r"``: concatenate (dim 1) the entry popped from ``ref_dict``.
    * ``"m"``: concatenate (dim 1) the entry from ``ref_dict`` without removing it.
    * ``"c"``: concatenate (dim 1) ``encoder_hidden_states`` with itself.
    """

    def __init__(self, chained_proc, enabled=False, name=None) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        mode="w",
        ref_dict: Optional[dict] = None,
        is_cfg_guidance=False,
    ) -> Any:
        """Run the chained processor, applying the reference ``mode`` if enabled.

        When ``enabled`` and ``is_cfg_guidance`` are both true, the first batch
        element is processed on its own (without any reference handling) and
        concatenated back in front of the result for the remaining elements.

        Raises:
            AssertionError: If the processor is enabled and ``mode`` is unknown.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        split_first = self.enabled and is_cfg_guidance
        if split_first:
            res0 = self.chained_proc(
                attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask
            )
            hidden_states = hidden_states[1:]
            encoder_hidden_states = encoder_hidden_states[1:]

        if self.enabled:
            if mode == "w":
                ref_dict[self.name] = encoder_hidden_states
            elif mode == "r":
                encoder_hidden_states = torch.cat(
                    [encoder_hidden_states, ref_dict.pop(self.name)], dim=1
                )
            elif mode == "m":
                encoder_hidden_states = torch.cat(
                    [encoder_hidden_states, ref_dict[self.name]], dim=1
                )
            elif mode == "c":
                encoder_hidden_states = torch.cat(
                    [encoder_hidden_states, encoder_hidden_states], dim=1
                )
            else:
                # Explicit raise so the check survives `python -O`.
                raise AssertionError(mode)

        res = self.chained_proc(
            attn, hidden_states, encoder_hidden_states, attention_mask
        )
        if split_first:
            res = torch.cat([res0, res])
        return res


class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper that runs a noised conditioning pass before the main pass.

    Every attention processor of ``unet`` is replaced by a
    ``ReferenceOnlyAttnProc``; only those whose name ends with
    ``"attn1.processor"`` are enabled.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        train_sched: DDPMScheduler,
        val_sched: EulerAncestralDiscreteScheduler,
    ) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        # Feature-detect SDPA instead of comparing version strings, which
        # orders lexicographically (e.g. "10.0" < "2.0").
        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            default_proc_cls = AttnProcessor2_0
        elif is_xformers_available():
            default_proc_cls = XFormersAttnProcessor
        else:
            default_proc_cls = AttnProcessor

        unet_lora_attn_procs = {
            name: ReferenceOnlyAttnProc(
                default_proc_cls(),
                enabled=name.endswith("attn1.processor"),
                name=name,
            )
            for name in unet.attn_processors
        }
        unet.set_attn_processor(unet_lora_attn_procs)

    def __getattr__(self, name: str):
        """Fall back to the wrapped UNet for attributes this module lacks."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(
        self,
        noisy_cond_lat,
        timestep,
        encoder_hidden_states,
        class_labels,
        ref_dict,
        is_cfg_guidance,
        **kwargs,
    ):
        """Run the UNet on the noised condition latent in write (``"w"``) mode.

        The UNet output is discarded; the call exists for its side effect of
        filling ``ref_dict`` through the enabled attention processors. With
        ``is_cfg_guidance`` the first batch entry of ``encoder_hidden_states``
        and ``class_labels`` is dropped before the call.
        """
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs,
        )

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        class_labels=None,
        *args,
        cross_attention_kwargs,
        down_block_res_samples=None,
        mid_block_res_sample=None,
        forward_cond_state=True,
        **kwargs,
    ):
        """Denoise ``sample`` with reference attention.

        ``cross_attention_kwargs`` must contain ``"cond_lat"``. It may contain
        ``"is_cfg_guidance"`` (default ``False``). If it contains the key
        ``"dont_forward_cond_state"`` the conditioning pass is skipped and the
        main pass runs in ``"c"`` mode; otherwise the conditioning pass runs
        first and the main pass uses ``"r"`` mode.

        ``forward_cond_state`` is accepted but not read by this method.

        Returns:
            Whatever the wrapped UNet returns for the main pass.
        """
        cond_lat = cross_attention_kwargs["cond_lat"]
        is_cfg_guidance = cross_attention_kwargs.get("is_cfg_guidance", False)

        # Noise is drawn unconditionally so RNG consumption is the same on
        # both the "r" and "c" paths.
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(
                noisy_cond_lat, timestep
            )
        else:
            flat_timestep = timestep.reshape(-1)
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, flat_timestep)
            noisy_cond_lat = self.val_sched.scale_model_input(
                noisy_cond_lat, timestep.reshape(-1)
            )

        ref_dict = {}
        if "dont_forward_cond_state" not in cross_attention_kwargs:
            self.forward_cond(
                noisy_cond_lat,
                timestep,
                encoder_hidden_states,
                class_labels,
                ref_dict,
                is_cfg_guidance,
                **kwargs,
            )
            mode = "r"
        else:
            mode = "c"

        weight_dtype = self.unet.dtype
        down_residuals = (
            [res.to(dtype=weight_dtype) for res in down_block_res_samples]
            if down_block_res_samples is not None
            else None
        )
        mid_residual = (
            mid_block_res_sample.to(dtype=weight_dtype)
            if mid_block_res_sample is not None
            else None
        )
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(
                mode=mode, ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance
            ),
            down_block_additional_residuals=down_residuals,
            mid_block_additional_residual=mid_residual,
            **kwargs,
        )


def scale_latents(latents):
    """Map latents with ``(latents - 0.22) * 0.75``; inverse of ``unscale_latents``."""
    latents = (latents - 0.22) * 0.75
    return latents


def unscale_latents(latents):
    """Map latents with ``latents / 0.75 + 0.22``; inverse of ``scale_latents``."""
    latents = latents / 0.75 + 0.22
    return latents


def scale_image(image):
    """Map an image with ``image * 0.5 / 0.8``; inverse of ``unscale_image``."""
    image = image * 0.5 / 0.8
    return image


def unscale_image(image):
    """Map an image with ``image / 0.5 * 0.8``; inverse of ``scale_image``."""
    image = image / 0.5 * 0.8
    return image


class DepthControlUNet(torch.nn.Module):
    """Wraps a ``RefOnlyNoisedUNet`` with a ControlNet fed by ``control_depth``."""

    def __init__(
        self,
        unet: RefOnlyNoisedUNet,
        controlnet: Optional[diffusers.ControlNetModel] = None,
        conditioning_scale=1.0,
    ) -> None:
        """Store the UNet and create a ControlNet from it when none is given."""
        super().__init__()
        self.unet = unet
        if controlnet is None:
            self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
        else:
            self.controlnet = controlnet
        DefaultAttnProc = AttnProcessor2_0
        if is_xformers_available():
            DefaultAttnProc = XFormersAttnProcessor
        self.controlnet.set_attn_processor(DefaultAttnProc())
        self.conditioning_scale = conditioning_scale

    def __getattr__(self, name: str):
        """Fall back to the wrapped UNet for attributes this module lacks."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        class_labels=None,
        *args,
        cross_attention_kwargs: dict,
        **kwargs,
    ):
        """Compute ControlNet residuals and pass them to the wrapped UNet.

        ``cross_attention_kwargs`` is copied, and ``"control_depth"`` is popped
        from the copy and used as the ControlNet conditioning input.

        NOTE(review): ``class_labels``, ``*args`` and ``**kwargs`` are accepted
        but not forwarded to either the ControlNet or the UNet - confirm this
        is intended.
        """
        cross_attention_kwargs = dict(cross_attention_kwargs)
        control_depth = cross_attention_kwargs.pop("control_depth")
        down_block_res_samples, mid_block_res_sample = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=control_depth,
            conditioning_scale=self.conditioning_scale,
            return_dict=False,
        )
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            cross_attention_kwargs=cross_attention_kwargs,
        )


class ModuleListDict(torch.nn.Module):
    """Read-only, key-addressable view over a ``ModuleList`` of processors."""

    def __init__(self, procs: dict) -> None:
        super().__init__()
        self.keys = sorted(procs.keys())
        self.values = torch.nn.ModuleList(procs[k] for k in self.keys)

    def __getitem__(self, key):
        """Return the module stored under ``key`` (``ValueError`` if absent)."""
        return self.values[self.keys.index(key)]


class SuperNet(torch.nn.Module):
    """Holds named modules in a ``ModuleList`` while keeping their names in state dicts.

    Hooks registered in ``__init__`` translate between the internal
    ``layers.<index>`` keys and the original names on ``state_dict()`` and
    ``load_state_dict()``.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value
            return new_state_dict

        def remap_key(key, state_dict):
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k
            return key.split(".")[0]

        def map_from(module, state_dict, *args, **kwargs):
            # Snapshot the keys: the dict is mutated while renaming.
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key, state_dict)
                new_key = key.replace(
                    replace_key, f"layers.{module.rev_mapping[replace_key]}"
                )
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)


class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    """Image-conditioned pipeline built on ``StableDiffusionPipeline``.

    Public entry points: ``__call__`` (generation from a condition image),
    ``edit_latents``, ``sdedit`` and ``refine``. ``prepare`` wraps the UNet in
    ``RefOnlyNoisedUNet`` on first use.
    """

    tokenizer: transformers.CLIPTokenizer
    text_encoder: transformers.CLIPTextModel
    vision_encoder: transformers.CLIPVisionModelWithProjection

    feature_extractor_clip: transformers.CLIPImageProcessor
    unet: UNet2DConditionModel
    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers

    vae: AutoencoderKL
    ramping: nn.Linear

    feature_extractor_vae: transformers.CLIPImageProcessor

    # Converts a PIL depth image to a tensor normalised with mean 0.5 / std 0.5.
    depth_transforms_multi = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    )

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vision_encoder: transformers.CLIPVisionModelWithProjection,
        feature_extractor_clip: CLIPImageProcessor,
        feature_extractor_vae: CLIPImageProcessor,
        ramping_coefficients: Optional[list] = None,
        safety_checker=None,
    ):
        """Register the sub-modules and config.

        The ``safety_checker`` argument is ignored: ``None`` is always
        registered in its place.
        """
        DiffusionPipeline.__init__(self)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae,
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def prepare(self):
        """Wrap a plain ``UNet2DConditionModel`` in ``RefOnlyNoisedUNet`` (eval mode).

        Does nothing to ``self.unet`` if it is already wrapped.
        """
        train_sched = DDPMScheduler.from_config(self.scheduler.config)
        if isinstance(self.unet, UNet2DConditionModel):
            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()

    def add_controlnet(
        self,
        controlnet: Optional[diffusers.ControlNetModel] = None,
        conditioning_scale=1.0,
    ):
        """Wrap the UNet in ``DepthControlUNet`` and return a ``SuperNet`` of the ControlNet."""
        self.prepare()
        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
        return SuperNet(OrderedDict([("controlnet", self.unet.controlnet)]))

    def encode_condition_image(self, image: torch.Tensor):
        """Encode ``image`` with the VAE and sample from the latent distribution."""
        image = self.vae.encode(image).latent_dist.sample()
        return image

    # ------------------------------------------------------------------
    # Private helpers shared by the public entry points.
    # ------------------------------------------------------------------

    def _text_hidden_states(self, prompt, num_images_per_prompt):
        """Encode ``prompt`` without CFG, via ``encode_prompt`` when available."""
        if hasattr(self, "encode_prompt"):
            return self.encode_prompt(
                prompt, self.device, num_images_per_prompt, False
            )[0]
        return self._encode_prompt(prompt, self.device, num_images_per_prompt, False)

    def _image_conditioned_hidden_states(
        self, clip_pixels, prompt, num_images_per_prompt
    ):
        """Add ramped vision-encoder image embeddings to the text hidden states."""
        encoded = self.vision_encoder(clip_pixels, output_hidden_states=False)
        global_embeds = encoded.image_embeds
        global_embeds = global_embeds.unsqueeze(-2)

        encoder_hidden_states = self._text_hidden_states(prompt, num_images_per_prompt)
        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        return encoder_hidden_states + global_embeds * ramp

    def _pil_to_unit_tensor(self, image):
        """Convert an image to a ``(1, C, H, W)`` tensor in ``[0, 1]`` on the VAE device.

        The tensor is cast to half precision only when the VAE is float16.
        """
        tensor = torch.from_numpy(numpy.array(image)).to(self.vae.device) / 255.0
        tensor = tensor.permute(2, 0, 1).unsqueeze(0)
        if self.vae.dtype == torch.float16:
            tensor = tensor.half()
        return tensor

    def _latents_to_output(self, latents, output_type, return_dict):
        """Decode ``latents`` (unless ``output_type == "latent"``) and wrap the result."""
        if output_type != "latent":
            image = unscale_image(
                self.vae.decode(
                    latents / self.vae.config.scaling_factor, return_dict=False
                )[0]
            )
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

    def _expand_prompt_embeds(
        self,
        prompt_embeds,
        cross_attention_kwargs,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
    ):
        """Pass pre-computed ``prompt_embeds`` through ``_encode_prompt``."""
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None)
            if cross_attention_kwargs is not None
            else None
        )
        return self._encode_prompt(
            None,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            None,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=None,
            lora_scale=text_encoder_lora_scale,
        )

    def _truncated_timesteps(self, num_inference_steps, edit_strength, device):
        """Set scheduler timesteps and keep the last ``int(edit_strength * n)`` of them."""
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        keep = int(edit_strength * len(timesteps))
        return reversed(reversed(timesteps)[:keep])

    def _denoise(
        self,
        latents,
        timesteps,
        prompt_embeds,
        cross_attention_kwargs,
        guidance_scale,
        do_classifier_free_guidance,
        extra_step_kwargs,
        cond_latent=None,
    ):
        """Run the denoising loop and return the final latents.

        When ``cond_latent`` is given it is concatenated to the scaled model
        input along dim 1 at every step.
        """
        num_warmup_steps = 0
        last_index = len(timesteps) - 1
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )
                if cond_latent is not None:
                    latent_model_input = torch.cat(
                        [latent_model_input, cond_latent], dim=1
                    )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]

                if i == last_index or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
        return latents

    def _finalize_denoised(self, latents, output_type, device, dtype):
        """Unscale latents, optionally decode, and build the pipeline output.

        ``has_nsfw_concept`` is ``None`` on the latent path so the function is
        well defined for every ``output_type``.
        """
        latents = unscale_latents(latents)
        if output_type != "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )

    # ------------------------------------------------------------------
    # Public entry points.
    # ------------------------------------------------------------------

    @spaces.GPU(duration=60)
    @torch.no_grad()
    def edit_latents(
        self,
        image_guidance: Image.Image,
        multiview_source_image: Image.Image = None,
        edit_strength: float = 1.0,
        prompt="",
        *args,
        guidance_scale=0.0,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        **kwargs,
    ):
        """Run the parent pipeline starting from latents of ``multiview_source_image``.

        ``image_guidance`` provides the condition latent (a zero-image latent
        is always prepended); ``multiview_source_image`` provides both the
        vision-encoder embedding and the initial latents.

        NOTE(review): ``latents`` are handed to the parent ``__call__``; if that
        call reaches ``prepare_latents`` without a ``timestep`` the override in
        this class raises ``ValueError`` - confirm against the diffusers
        version in use.

        Raises:
            ValueError: If either image is ``None``.

        Returns:
            ``ImagePipelineOutput`` or, if ``return_dict`` is false, a 1-tuple.
        """
        self.prepare()
        if image_guidance is None:
            raise ValueError(
                "Inputting embeddings not supported for this pipeline. Please pass an image."
            )
        if multiview_source_image is None:
            raise ValueError("Multiview source image is required for this pipeline.")
        assert not isinstance(image_guidance, torch.Tensor)
        assert not isinstance(multiview_source_image, torch.Tensor)

        image_guidance = to_rgb_image(image_guidance)
        image_source = to_rgb_image(multiview_source_image)

        image_guidance_1 = self.feature_extractor_vae(
            images=image_guidance, return_tensors="pt"
        ).pixel_values
        image_guidance_2 = self.feature_extractor_clip(
            images=image_source, return_tensors="pt"
        ).pixel_values

        image_guidance = image_guidance_1.to(
            device=self.vae.device, dtype=self.vae.dtype
        )
        image_guidance_2 = image_guidance_2.to(
            device=self.vae.device, dtype=self.vae.dtype
        )

        cond_lat = self.encode_condition_image(image_guidance)
        # The negative latent is prepended regardless of `guidance_scale`.
        negative_lat = self.encode_condition_image(torch.zeros_like(image_guidance))
        cond_lat = torch.cat([negative_lat, cond_lat])

        encoder_hidden_states = self._image_conditioned_hidden_states(
            image_guidance_2, prompt, 1
        )
        cak = dict(cond_lat=cond_lat)

        # Use the RGB-converted source so RGBA inputs do not reach the VAE
        # with four channels.
        mv_image = torch.from_numpy(numpy.array(image_source)).to(self.vae.device) / 255.0
        mv_image = (
            mv_image.permute(2, 0, 1)
            .to(self.vae.device)
            .to(self.vae.dtype)
            .unsqueeze(0)
        )
        latents = (
            self.vae.encode(mv_image * 2.0 - 1.0).latent_dist.sample()
            * self.vae.config.scaling_factor
        )

        latents: torch.Tensor = (
            super()
            .__call__(
                None,
                *args,
                cross_attention_kwargs=cak,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                prompt_embeds=encoder_hidden_states,
                num_inference_steps=num_inference_steps,
                output_type="latent",
                width=width,
                height=height,
                latents=latents,
                edit_strength=edit_strength,
                **kwargs,
            )
            .images
        )
        latents = unscale_latents(latents)
        return self._latents_to_output(latents, output_type, return_dict)

    @torch.no_grad()
    def encode_target_images(self, images):
        """Encode ``[0, 1]`` images into scaled latents.

        Returns:
            ``scale_latents(sample * vae.config.scaling_factor)``.
        """
        dtype = next(self.vae.parameters()).dtype
        # equals to scaling images to [-1, 1] first and then call scale_image
        images = (images - 0.5) / 0.8  # [-0.625, 0.625]
        posterior = self.vae.encode(images.to(dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        latents = scale_latents(latents)
        return latents

    @spaces.GPU(duration=60)
    @torch.no_grad()
    def sdedit(
        self,
        image,
        *args,
        cond_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=75,
        edit_strength=1.0,
        return_dict=True,
        guidance_scale=0.0,
        **kwargs,
    ):
        """Noise the latents of ``image`` and denoise them with an empty prompt.

        The condition latent comes from ``cond_image`` when given, otherwise
        from an all-zero image. The conditioning UNet pass is disabled via
        ``dont_forward_cond_state``. Positional ``*args`` are accepted but
        unused.

        Raises:
            ValueError: If ``image`` is ``None``.

        Returns:
            ``ImagePipelineOutput`` or, if ``return_dict`` is false, a 1-tuple.
        """
        self.prepare()
        if image is None:
            raise ValueError(
                "Inputting embeddings not supported for this pipeline. Please pass an image."
            )
        assert not isinstance(image, torch.Tensor)

        image = to_rgb_image(image)
        encoder_hidden_states = self._text_hidden_states([""], 1)

        image = self._pil_to_unit_tensor(image)
        latents = self.encode_target_images(image)

        if cond_image is not None:
            cond_image = self._pil_to_unit_tensor(to_rgb_image(cond_image))
            cond_lat = self.encode_condition_image(cond_image)
        else:
            cond_lat = self.encode_condition_image(torch.zeros_like(image)).to(
                self.vae.device
            )

        cak = dict(cond_lat=cond_lat, dont_forward_cond_state=True)
        # `forward_sdedit` already unscales the latents it returns.
        latents = self.forward_sdedit(
            latents,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            width=width,
            height=height,
            edit_strength=edit_strength,
            **kwargs,
        ).images
        return self._latents_to_output(latents, output_type, return_dict)

    @spaces.GPU(duration=60)
    @torch.no_grad()
    def refine(
        self,
        image: Image.Image = None,
        edit_image: Image.Image = None,
        prompt: Optional[str] = "",
        *args,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        edit_strength=1.0,
        return_dict=True,
        guidance_scale=4.0,
        **kwargs,
    ):
        """Denoise with the latents of ``image`` concatenated to the UNet input.

        ``edit_image``, when given, provides the starting latents (otherwise
        they are sampled as noise). ``width`` and ``height`` are overwritten by
        the pixel size of ``image``. Positional ``*args`` are accepted but
        unused.

        Raises:
            ValueError: If ``image`` is ``None``.

        Returns:
            ``ImagePipelineOutput`` or, if ``return_dict`` is false, a 1-tuple.
        """
        self.prepare()
        if image is None:
            raise ValueError(
                "Inputting embeddings not supported for this pipeline. Please pass an image."
            )
        assert not isinstance(image, torch.Tensor)

        image = to_rgb_image(image)
        encoder_hidden_states = self._text_hidden_states(prompt, 1)

        latents_edit = None
        if edit_image is not None:
            edit_image = self._pil_to_unit_tensor(to_rgb_image(edit_image))
            latents_edit = self.encode_target_images(edit_image)

        image = self._pil_to_unit_tensor(image)
        height, width = image.shape[-2:]
        latents = self.encode_target_images(image)

        cond_lat = self.encode_condition_image(torch.zeros_like(image)).to(
            self.vae.device
        )
        cak = dict(cond_lat=cond_lat, dont_forward_cond_state=True)
        # `forward_pipeline` already unscales the latents it returns.
        latents = self.forward_pipeline(
            latents_edit,
            latents,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type="latent",
            width=width,
            height=height,
            edit_strength=edit_strength,
            **kwargs,
        ).images
        return self._latents_to_output(latents, output_type, return_dict)

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
        timestep=None,
    ):
        """Create initial latents, or noise the given ones to ``timestep``.

        Without ``latents`` fresh noise is drawn and multiplied by the
        scheduler's ``init_noise_sigma``. With ``latents`` noise is added via
        ``scheduler.add_noise`` at ``timestep``.

        Raises:
            ValueError: If ``generator`` is a list whose length differs from
                ``batch_size``, or ``latents`` is given without ``timestep``.
        """
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
        else:
            if timestep is None:
                raise ValueError(
                    "When passing `latents` you also need to pass `timestep`."
                )
            latents = latents.to(device)
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # get latents
            latents = self.scheduler.add_noise(latents, noise, timestep)
        return latents

    @torch.no_grad()
    def forward_sdedit(
        self,
        latents: torch.Tensor,
        cross_attention_kwargs: dict,
        guidance_scale: float,
        num_images_per_prompt: int,
        prompt_embeds,
        num_inference_steps: int,
        output_type: str,
        width: int,
        height: int,
        edit_strength: float = 1.0,
    ):
        """Noise ``latents`` to the first kept timestep and denoise them.

        Classifier-free guidance is enabled when ``guidance_scale > 1.0``.

        Returns:
            ``StableDiffusionPipelineOutput`` whose images are unscaled latents
            when ``output_type == "latent"`` and decoded images otherwise.
        """
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        batch_size = prompt_embeds.shape[0]
        generator = torch.Generator(device=latents.device)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds = self._expand_prompt_embeds(
            prompt_embeds,
            cross_attention_kwargs,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
        )
        timesteps = self._truncated_timesteps(
            num_inference_steps, edit_strength, device
        )

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            timesteps[0:1],
        )
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)

        latents = self._denoise(
            latents,
            timesteps,
            prompt_embeds,
            cross_attention_kwargs,
            guidance_scale,
            do_classifier_free_guidance,
            extra_step_kwargs,
        )
        return self._finalize_denoised(
            latents, output_type, device, prompt_embeds.dtype
        )

    @torch.no_grad()
    def forward_pipeline(
        self,
        latents: torch.Tensor,
        cond_latent: torch.Tensor,
        cross_attention_kwargs: dict,
        guidance_scale: float,
        num_images_per_prompt: int,
        prompt_embeds,
        num_inference_steps: int,
        output_type: str,
        width: int,
        height: int,
        edit_strength: float = 1.0,
    ):
        """Denoise with ``cond_latent`` concatenated to the UNet input channels.

        ``latents`` may be ``None``, in which case noise is sampled. Half of
        the UNet's ``in_channels`` is used for the denoised latents. Classifier-
        free guidance is enabled when ``guidance_scale > 1.0``.

        Returns:
            ``StableDiffusionPipelineOutput`` whose images are unscaled latents
            when ``output_type == "latent"`` and decoded images otherwise.
        """
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        batch_size = 1
        generator = torch.Generator(device=cond_latent.device)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds = self._expand_prompt_embeds(
            prompt_embeds,
            cross_attention_kwargs,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
        )
        timesteps = self._truncated_timesteps(
            num_inference_steps, edit_strength, device
        )

        num_channels_latents = self.unet.config.in_channels // 2
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            timesteps[0:1],
        )
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)

        if do_classifier_free_guidance:
            cond_latent = cond_latent.expand(batch_size * 2, -1, -1, -1)

        latents = self._denoise(
            latents,
            timesteps,
            prompt_embeds,
            cross_attention_kwargs,
            guidance_scale,
            do_classifier_free_guidance,
            extra_step_kwargs,
            cond_latent=cond_latent,
        )
        return self._finalize_denoised(
            latents, output_type, device, prompt_embeds.dtype
        )

    @spaces.GPU(duration=60)
    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        source_image: Image.Image = None,
        prompt="",
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=4.0,
        depth_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        **kwargs,
    ):
        """Generate from a condition ``image`` through the parent pipeline.

        A zero-image latent is prepended to the condition latent when
        ``guidance_scale > 1``. ``depth_image`` is used only when the UNet has
        a ``controlnet`` attribute. ``source_image`` is accepted but unused.

        Raises:
            ValueError: If ``image`` is ``None``.

        Returns:
            ``ImagePipelineOutput`` or, if ``return_dict`` is false, a 1-tuple.
        """
        self.prepare()
        if image is None:
            raise ValueError(
                "Inputting embeddings not supported for this pipeline. Please pass an image."
            )
        assert not isinstance(image, torch.Tensor)

        image = to_rgb_image(image)
        image_1 = self.feature_extractor_vae(
            images=image, return_tensors="pt"
        ).pixel_values
        image_2 = self.feature_extractor_clip(
            images=image, return_tensors="pt"
        ).pixel_values

        has_controlnet = hasattr(self.unet, "controlnet")
        if depth_image is not None and has_controlnet:
            depth_image = to_rgb_image(depth_image)
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )

        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)

        cond_lat = self.encode_condition_image(image)
        if guidance_scale > 1:
            negative_lat = self.encode_condition_image(torch.zeros_like(image))
            cond_lat = torch.cat([negative_lat, cond_lat])

        encoder_hidden_states = self._image_conditioned_hidden_states(
            image_2, prompt, num_images_per_prompt
        )

        cak = dict(cond_lat=cond_lat)
        if has_controlnet:
            cak["control_depth"] = depth_image

        latents: torch.Tensor = (
            super()
            .__call__(
                None,
                *args,
                cross_attention_kwargs=cak,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images_per_prompt,
                prompt_embeds=encoder_hidden_states,
                num_inference_steps=num_inference_steps,
                output_type="latent",
                width=width,
                height=height,
                latents=None,
                **kwargs,
            )
            .images
        )
        latents = unscale_latents(latents)
        return self._latents_to_output(latents, output_type, return_dict)