linoyts (HF staff) committed
Commit 6f5762b
1 Parent(s): 217a63b

Update app.py

Files changed (1)
  1. app.py +1866 -3
app.py CHANGED
@@ -1,3 +1,1866 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
@@ -7,7 +1870,7 @@ import time
7
  import numpy as np
8
  import cv2
9
  from PIL import Image
10
- from ledits.pipeline_leditspp_stable_diffusion_xl import LEditsPPPipelineStableDiffusionXL
11
 
12
  def HWC3(x):
13
  assert x.dtype == np.uint8
@@ -162,7 +2025,7 @@ def update_y(x,y,prompt, seed, steps,
162
  return image
163
 
164
  @spaces.GPU
165
- def invert(image, num_inversion_steps=50, skip=0.3):
166
  image = image.resize((512,512))
167
  init_latents,zs = clip_slider_inv.pipe.invert(
168
  source_prompt = "",
@@ -334,7 +2197,7 @@ with gr.Blocks(css=css) as demo:
334
  inputs=[slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2],
335
  outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image])
336
 
337
- image_inv.change(fn=reset_do_inversion, outputs=[do_inversion]).then(fn=invert, inputs=[image_inv], outputs=[init_latents,zs])
338
  submit_inv.click(fn=generate,
339
  inputs=[slider_x_inv, slider_y_inv, prompt_inv, seed_inv, iterations_inv, steps_inv, guidance_scale_inv, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, img2img_type_inv, image, controlnet_conditioning_scale, ip_adapter_scale ,edit_threshold, edit_guidance_scale, init_latents, zs],
340
  outputs=[x_inv, y_inv, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image_inv])
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import math
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from transformers import (
22
+ CLIPImageProcessor,
23
+ CLIPTextModel,
24
+ CLIPTextModelWithProjection,
25
+ CLIPTokenizer,
26
+ CLIPVisionModelWithProjection,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import (
31
+ FromSingleFileMixin,
32
+ IPAdapterMixin,
33
+ StableDiffusionXLLoraLoaderMixin,
34
+ TextualInversionLoaderMixin,
35
+ )
36
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
37
+ from diffusers.models.attention_processor import (
38
+ Attention,
39
+ AttnProcessor,
40
+ AttnProcessor2_0,
41
+ XFormersAttnProcessor,
42
+ )
43
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
44
+ from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler
45
+ from diffusers.utils import (
46
+ USE_PEFT_BACKEND,
47
+ is_invisible_watermark_available,
48
+ is_torch_xla_available,
49
+ logging,
50
+ replace_example_docstring,
51
+ scale_lora_layers,
52
+ unscale_lora_layers,
53
+ )
54
+ from diffusers.utils.torch_utils import randn_tensor
55
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
56
+ from ledits.pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
57
+
58
+
59
+ if is_invisible_watermark_available():
60
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
61
+
62
+ if is_torch_xla_available():
63
+ import torch_xla.core.xla_model as xm
64
+
65
+ XLA_AVAILABLE = True
66
+ else:
67
+ XLA_AVAILABLE = False
68
+
69
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
70
+
71
+ EXAMPLE_DOC_STRING = """
72
+ Examples:
73
+ ```py
74
+ >>> import torch
75
+ >>> import PIL
76
+ >>> import requests
77
+ >>> from io import BytesIO
78
+
79
+ >>> from diffusers import LEditsPPPipelineStableDiffusionXL
80
+
81
+ >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
82
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
83
+ ... )
84
+ >>> pipe = pipe.to("cuda")
85
+
86
+
87
+ >>> def download_image(url):
88
+ ... response = requests.get(url)
89
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
90
+
91
+
92
+ >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
93
+ >>> image = download_image(img_url)
94
+
95
+ >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
96
+
97
+ >>> edited_image = pipe(
98
+ ... editing_prompt=["tennis ball", "tomato"],
99
+ ... reverse_editing_direction=[True, False],
100
+ ... edit_guidance_scale=[5.0, 10.0],
101
+ ... edit_threshold=[0.9, 0.85],
102
+ ... ).images[0]
103
+ ```
104
+ """
105
+
106
+
107
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LeditsAttentionStore
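+ # Accumulates the cross-/self-attention maps recorded at each denoising step; they are later combined into per-prompt masks via `aggregate_attention`.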
108
+ class LeditsAttentionStore:
109
+ @staticmethod
110
+ def get_empty_store():
111
+ return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}
112
+
113
+ def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False):
114
+ # attn.shape = batch_size * head_size, seq_len query, seq_len_key
115
+ if attn.shape[1] <= self.max_size:
116
+ bs = 1 + int(PnP) + editing_prompts
117
+ skip = 2 if PnP else 1 # skip PnP & unconditional
118
+ attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3)
119
+ source_batch_size = int(attn.shape[1] // bs)
120
+ self.forward(attn[:, skip * source_batch_size :], is_cross, place_in_unet)
121
+
122
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
123
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
124
+
125
+ self.step_store[key].append(attn)
126
+
127
+ def between_steps(self, store_step=True):
128
+ if store_step:
129
+ if self.average:
130
+ if len(self.attention_store) == 0:
131
+ self.attention_store = self.step_store
132
+ else:
133
+ for key in self.attention_store:
134
+ for i in range(len(self.attention_store[key])):
135
+ self.attention_store[key][i] += self.step_store[key][i]
136
+ else:
137
+ if len(self.attention_store) == 0:
138
+ self.attention_store = [self.step_store]
139
+ else:
140
+ self.attention_store.append(self.step_store)
141
+
142
+ self.cur_step += 1
143
+ self.step_store = self.get_empty_store()
144
+
145
+ def get_attention(self, step: int):
146
+ if self.average:
147
+ attention = {
148
+ key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store
149
+ }
150
+ else:
151
+ assert step is not None
152
+ attention = self.attention_store[step]
153
+ return attention
154
+
155
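+ # Collects the stored attention maps whose spatial size matches the requested resolution and averages them over attention heads.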
+ def aggregate_attention(
156
+ self, attention_maps, prompts, res: Union[int, Tuple[int]], from_where: List[str], is_cross: bool, select: int
157
+ ):
158
+ out = [[] for x in range(self.batch_size)]
159
+ if isinstance(res, int):
160
+ num_pixels = res**2
161
+ resolution = (res, res)
162
+ else:
163
+ num_pixels = res[0] * res[1]
164
+ resolution = res[:2]
165
+
166
+ for location in from_where:
167
+ for bs_item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
168
+ for batch, item in enumerate(bs_item):
169
+ if item.shape[1] == num_pixels:
170
+ cross_maps = item.reshape(len(prompts), -1, *resolution, item.shape[-1])[select]
171
+ out[batch].append(cross_maps)
172
+
173
+ out = torch.stack([torch.cat(x, dim=0) for x in out])
174
+ # average over heads
175
+ out = out.sum(1) / out.shape[1]
176
+ return out
177
+
178
+ def __init__(self, average: bool, batch_size=1, max_resolution=16, max_size: int = None):
179
+ self.step_store = self.get_empty_store()
180
+ self.attention_store = []
181
+ self.cur_step = 0
182
+ self.average = average
183
+ self.batch_size = batch_size
184
+ if max_size is None:
185
+ self.max_size = max_resolution**2
186
+ elif max_size is not None and max_resolution is None:
187
+ self.max_size = max_size
188
+ else:
189
+ raise ValueError("Only allowed to set one of max_resolution or max_size")
190
+
191
+
192
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LeditsGaussianSmoothing
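+ # Builds a fixed 3x3 Gaussian kernel (sigma=0.5) applied as a depthwise convolution to smooth the cross-attention masks.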
193
+ class LeditsGaussianSmoothing:
194
+ def __init__(self, device):
195
+ kernel_size = [3, 3]
196
+ sigma = [0.5, 0.5]
197
+
198
+ # The gaussian kernel is the product of the gaussian function of each dimension.
199
+ kernel = 1
200
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
201
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
202
+ mean = (size - 1) / 2
203
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
204
+
205
+ # Make sure sum of values in gaussian kernel equals 1.
206
+ kernel = kernel / torch.sum(kernel)
207
+
208
+ # Reshape to depthwise convolutional weight
209
+ kernel = kernel.view(1, 1, *kernel.size())
210
+ kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1))
211
+
212
+ self.weight = kernel.to(device)
213
+
214
+ def __call__(self, input):
215
+ """
216
+ Apply gaussian filter to input.
217
+ Arguments:
218
+ input (torch.Tensor): Input to apply gaussian filter on.
219
+ Returns:
220
+ filtered (torch.Tensor): Filtered output.
221
+ """
222
+ return F.conv2d(input, weight=self.weight.to(input.dtype))
223
+
224
+
225
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEDITSCrossAttnProcessor
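+ # Attention processor that passes the cross-attention probabilities to the LeditsAttentionStore before computing the usual attention output.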
226
+ class LEDITSCrossAttnProcessor:
227
+ def __init__(self, attention_store, place_in_unet, pnp, editing_prompts):
228
+ self.attnstore = attention_store
229
+ self.place_in_unet = place_in_unet
230
+ self.editing_prompts = editing_prompts
231
+ self.pnp = pnp
232
+
233
+ def __call__(
234
+ self,
235
+ attn: Attention,
236
+ hidden_states,
237
+ encoder_hidden_states,
238
+ attention_mask=None,
239
+ temb=None,
240
+ ):
241
+ batch_size, sequence_length, _ = (
242
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
243
+ )
244
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
245
+
246
+ query = attn.to_q(hidden_states)
247
+
248
+ if encoder_hidden_states is None:
249
+ encoder_hidden_states = hidden_states
250
+ elif attn.norm_cross:
251
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
252
+
253
+ key = attn.to_k(encoder_hidden_states)
254
+ value = attn.to_v(encoder_hidden_states)
255
+
256
+ query = attn.head_to_batch_dim(query)
257
+ key = attn.head_to_batch_dim(key)
258
+ value = attn.head_to_batch_dim(value)
259
+
260
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
261
+ self.attnstore(
262
+ attention_probs,
263
+ is_cross=True,
264
+ place_in_unet=self.place_in_unet,
265
+ editing_prompts=self.editing_prompts,
266
+ PnP=self.pnp,
267
+ )
268
+
269
+ hidden_states = torch.bmm(attention_probs, value)
270
+ hidden_states = attn.batch_to_head_dim(hidden_states)
271
+
272
+ # linear proj
273
+ hidden_states = attn.to_out[0](hidden_states)
274
+ # dropout
275
+ hidden_states = attn.to_out[1](hidden_states)
276
+
277
+ hidden_states = hidden_states / attn.rescale_output_factor
278
+ return hidden_states
279
+
280
+
281
+ class LEditsPPPipelineStableDiffusionXL(
282
+ DiffusionPipeline,
283
+ FromSingleFileMixin,
284
+ StableDiffusionXLLoraLoaderMixin,
285
+ TextualInversionLoaderMixin,
286
+ IPAdapterMixin,
287
+ ):
288
+ """
289
+ Pipeline for textual image editing using LEDits++ with Stable Diffusion XL.
290
+
291
+ This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionXLPipeline`]. Check the
292
+ superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a
293
+ particular device, etc.).
294
+
295
+ In addition, the pipeline inherits the following loading methods:
296
+ - *LoRA*: [`LEditsPPPipelineStableDiffusionXL.load_lora_weights`]
297
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
298
+
299
+ as well as the following saving methods:
300
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
301
+
302
+ Args:
303
+ vae ([`AutoencoderKL`]):
304
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
305
+ text_encoder ([`~transformers.CLIPTextModel`]):
306
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
307
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
308
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
309
+ text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
310
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
311
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
312
+ specifically the
313
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
314
+ variant.
315
+ tokenizer ([`~transformers.CLIPTokenizer`]):
316
+ Tokenizer of class
317
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
318
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]):
319
+ Second Tokenizer of class
320
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
321
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
322
+ scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]):
323
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
324
+ [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will
325
+ automatically be set to [`DPMSolverMultistepScheduler`].
326
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
327
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
328
+ `stabilityai/stable-diffusion-xl-base-1-0`.
329
+ add_watermarker (`bool`, *optional*):
330
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
331
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
332
+ watermarker will be used.
333
+ """
334
+
335
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
336
+ _optional_components = [
337
+ "tokenizer",
338
+ "tokenizer_2",
339
+ "text_encoder",
340
+ "text_encoder_2",
341
+ "image_encoder",
342
+ "feature_extractor",
343
+ ]
344
+ _callback_tensor_inputs = [
345
+ "latents",
346
+ "prompt_embeds",
347
+ "negative_prompt_embeds",
348
+ "add_text_embeds",
349
+ "add_time_ids",
350
+ "negative_pooled_prompt_embeds",
351
+ "negative_add_time_ids",
352
+ ]
353
+
354
+ def __init__(
355
+ self,
356
+ vae: AutoencoderKL,
357
+ text_encoder: CLIPTextModel,
358
+ text_encoder_2: CLIPTextModelWithProjection,
359
+ tokenizer: CLIPTokenizer,
360
+ tokenizer_2: CLIPTokenizer,
361
+ unet: UNet2DConditionModel,
362
+ scheduler: Union[DPMSolverMultistepScheduler, DDIMScheduler],
363
+ image_encoder: CLIPVisionModelWithProjection = None,
364
+ feature_extractor: CLIPImageProcessor = None,
365
+ force_zeros_for_empty_prompt: bool = True,
366
+ add_watermarker: Optional[bool] = None,
367
+ ):
368
+ super().__init__()
369
+
370
+ self.register_modules(
371
+ vae=vae,
372
+ text_encoder=text_encoder,
373
+ text_encoder_2=text_encoder_2,
374
+ tokenizer=tokenizer,
375
+ tokenizer_2=tokenizer_2,
376
+ unet=unet,
377
+ scheduler=scheduler,
378
+ image_encoder=image_encoder,
379
+ feature_extractor=feature_extractor,
380
+ )
381
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
382
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
383
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
384
+
385
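+ # Only DDIMScheduler and DPMSolverMultistepScheduler are supported; any other scheduler is replaced with DPMSolverMultistepScheduler (algorithm_type="sde-dpmsolver++", solver_order=2).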
+ if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler):
386
+ self.scheduler = DPMSolverMultistepScheduler.from_config(
387
+ scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
388
+ )
389
+ logger.warning(
390
+ "This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. "
391
+ "The scheduler has been changed to DPMSolverMultistepScheduler."
392
+ )
393
+
394
+ self.default_sample_size = self.unet.config.sample_size
395
+
396
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
397
+
398
+ if add_watermarker:
399
+ self.watermark = StableDiffusionXLWatermarker()
400
+ else:
401
+ self.watermark = None
402
+ self.inversion_steps = None
403
+
404
+ def encode_prompt(
405
+ self,
406
+ device: Optional[torch.device] = None,
407
+ num_images_per_prompt: int = 1,
408
+ negative_prompt: Optional[str] = None,
409
+ negative_prompt_2: Optional[str] = None,
410
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
411
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
412
+ lora_scale: Optional[float] = None,
413
+ clip_skip: Optional[int] = None,
414
+ enable_edit_guidance: bool = True,
415
+ editing_prompt: Optional[str] = None,
416
+ editing_prompt_embeds: Optional[torch.Tensor] = None,
417
+ editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
418
+ avg_diff = None,
419
+ avg_diff_2 = None,
420
+ correlation_weight_factor = 0.7,
421
+ scale=2,
422
+ ) -> object:
423
+ r"""
424
+ Encodes the prompt into text encoder hidden states.
425
+
426
+ Args:
427
+ device: (`torch.device`):
428
+ torch device
429
+ num_images_per_prompt (`int`):
430
+ number of images that should be generated per prompt
431
+ negative_prompt (`str` or `List[str]`, *optional*):
432
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
433
+ `negative_prompt_embeds` instead.
434
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
435
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
436
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
437
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
438
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
439
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
440
+ argument.
441
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
442
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
443
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
444
+ input argument.
445
+ lora_scale (`float`, *optional*):
446
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
447
+ clip_skip (`int`, *optional*):
448
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
449
+ the output of the pre-final layer will be used for computing the prompt embeddings.
450
+ enable_edit_guidance (`bool`):
451
+ Whether to guide towards an editing prompt or not.
452
+ editing_prompt (`str` or `List[str]`, *optional*):
453
+ Editing prompt(s) to be encoded. If not defined and 'enable_edit_guidance' is True, one has to pass
454
+ `editing_prompt_embeds` instead.
455
+ editing_prompt_embeds (`torch.Tensor`, *optional*):
456
+ Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
457
+ If not provided and 'enable_edit_guidance' is True, editing_prompt_embeds will be generated from
458
+ `editing_prompt` input argument.
459
+ editing_pooled_prompt_embeds (`torch.Tensor`, *optional*):
460
+ Pre-generated edit pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
461
+ weighting. If not provided, pooled editing_pooled_prompt_embeds will be generated from `editing_prompt`
462
+ input argument.
463
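+ avg_diff (`torch.Tensor`, *optional*):
+ Precomputed concept-direction vector added to the token embeddings of the first text encoder.
+ avg_diff_2 (`torch.Tensor`, *optional*):
+ Corresponding concept-direction vector for the second text encoder.
+ correlation_weight_factor (`float`, *optional*, defaults to 0.7):
+ Interpolation factor that blends the per-token similarity weights towards 1.0 before the shift is applied.
+ scale (`float`, *optional*, defaults to 2):
+ Strength of the concept-direction shift.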
+ """
464
+ device = device or self._execution_device
465
+
466
+ # set lora scale so that monkey patched LoRA
467
+ # function of text encoder can correctly access it
468
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
469
+ self._lora_scale = lora_scale
470
+
471
+ # dynamically adjust the LoRA scale
472
+ if self.text_encoder is not None:
473
+ if not USE_PEFT_BACKEND:
474
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
475
+ else:
476
+ scale_lora_layers(self.text_encoder, lora_scale)
477
+
478
+ if self.text_encoder_2 is not None:
479
+ if not USE_PEFT_BACKEND:
480
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
481
+ else:
482
+ scale_lora_layers(self.text_encoder_2, lora_scale)
483
+
484
+ batch_size = self.batch_size
485
+
486
+ # Define tokenizers and text encoders
487
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
488
+ text_encoders = (
489
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
490
+ )
491
+ num_edit_tokens = 0
492
+
493
+ # get unconditional embeddings for classifier free guidance
494
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
495
+
496
+ if negative_prompt_embeds is None:
497
+ negative_prompt = negative_prompt or ""
498
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
499
+
500
+ # normalize str to list
501
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
502
+ negative_prompt_2 = (
503
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
504
+ )
505
+
506
+ uncond_tokens: List[str]
507
+
508
+ if batch_size != len(negative_prompt):
509
+ raise ValueError(
510
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but image inversion "
511
+ f" has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
512
+ " the batch size of the input images."
513
+ )
514
+ else:
515
+ uncond_tokens = [negative_prompt, negative_prompt_2]
516
+
517
+ j=0
518
+ negative_prompt_embeds_list = []
519
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
520
+ if isinstance(self, TextualInversionLoaderMixin):
521
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
522
+
523
+
524
+ uncond_input = tokenizer(
525
+ negative_prompt,
526
+ padding="max_length",
527
+ max_length=tokenizer.model_max_length,
528
+ truncation=True,
529
+ return_tensors="pt",
530
+ )
531
+ toks = uncond_input.input_ids
532
+
533
+ negative_prompt_embeds = text_encoder(
534
+ uncond_input.input_ids.to(device),
535
+ output_hidden_states=True,
536
+ )
537
+ # We are only ALWAYS interested in the pooled output of the final text encoder
538
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
539
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
540
+
541
+ if avg_diff is not None and avg_diff_2 is not None:
544
+ normed_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
545
+ sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
546
+ if j == 0:
547
+ weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
548
+
549
+ standard_weights = torch.ones_like(weights)
550
+
551
+ weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
552
+ edit_concepts_embeds = negative_prompt_embeds + (weights * avg_diff[None, :].repeat(1,tokenizer.model_max_length, 1) * scale)
553
+ else:
554
+ weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
555
+
556
+ standard_weights = torch.ones_like(weights)
557
+
558
+ weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
559
+ edit_concepts_embeds = negative_prompt_embeds + (weights * avg_diff_2[None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
560
+
561
+
562
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
563
+ j+=1
564
+
565
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
566
+
567
+ if zero_out_negative_prompt:
568
+ negative_prompt_embeds = torch.zeros_like(negative_prompt_embeds)
569
+ negative_pooled_prompt_embeds = torch.zeros_like(negative_pooled_prompt_embeds)
570
+
571
+ if enable_edit_guidance and editing_prompt_embeds is None:
572
+ editing_prompt_2 = editing_prompt
573
+
574
+ editing_prompts = [editing_prompt, editing_prompt_2]
575
+ edit_prompt_embeds_list = []
576
+
577
+ i = 0
578
+ for editing_prompt, tokenizer, text_encoder in zip(editing_prompts, tokenizers, text_encoders):
579
+ if isinstance(self, TextualInversionLoaderMixin):
580
+ editing_prompt = self.maybe_convert_prompt(editing_prompt, tokenizer)
581
+
582
+ max_length = negative_prompt_embeds.shape[1]
583
+ edit_concepts_input = tokenizer(
584
+ # [x for item in editing_prompt for x in repeat(item, batch_size)],
585
+ editing_prompt,
586
+ padding="max_length",
587
+ max_length=max_length,
588
+ truncation=True,
589
+ return_tensors="pt",
590
+ return_length=True,
591
+ )
592
+ num_edit_tokens = edit_concepts_input.length - 2
593
+ toks = edit_concepts_input.input_ids
594
+ edit_concepts_embeds = text_encoder(
595
+ edit_concepts_input.input_ids.to(device),
596
+ output_hidden_states=True,
597
+ )
598
+ # We are only ALWAYS interested in the pooled output of the final text encoder
599
+ editing_pooled_prompt_embeds = edit_concepts_embeds[0]
600
+ if clip_skip is None:
601
+ edit_concepts_embeds = edit_concepts_embeds.hidden_states[-2]
602
+ else:
603
+ # "2" because SDXL always indexes from the penultimate layer.
604
+ edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
605
+
607
+ if avg_diff is not None and avg_diff_2 is not None:
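+ # Shift each token embedding along the provided direction (avg_diff for the first text encoder, avg_diff_2 for the second), weighted by its similarity to the EOS token, blended via correlation_weight_factor, and scaled by `scale`.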
610
+ normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
611
+ sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
612
+ if i == 0:
613
+ weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
614
+
615
+ standard_weights = torch.ones_like(weights)
616
+
617
+ weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
618
+ edit_concepts_embeds = edit_concepts_embeds + (weights * avg_diff[None, :].repeat(1,tokenizer.model_max_length, 1) * scale)
619
+ else:
620
+ weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
621
+
622
+ standard_weights = torch.ones_like(weights)
623
+
624
+ weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
625
+ edit_concepts_embeds = edit_concepts_embeds + (weights * avg_diff_2[None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
626
+
627
+ edit_prompt_embeds_list.append(edit_concepts_embeds)
628
+ i+=1
629
+
630
+ edit_concepts_embeds = torch.concat(edit_prompt_embeds_list, dim=-1)
631
+ elif not enable_edit_guidance:
632
+ edit_concepts_embeds = None
633
+ editing_pooled_prompt_embeds = None
634
+
635
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
636
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
637
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
638
+ seq_len = negative_prompt_embeds.shape[1]
639
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
640
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
641
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
642
+
643
+ if enable_edit_guidance:
644
+ bs_embed_edit, seq_len, _ = edit_concepts_embeds.shape
645
+ edit_concepts_embeds = edit_concepts_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
646
+ edit_concepts_embeds = edit_concepts_embeds.repeat(1, num_images_per_prompt, 1)
647
+ edit_concepts_embeds = edit_concepts_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len, -1)
648
+
649
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
650
+ bs_embed * num_images_per_prompt, -1
651
+ )
652
+
653
+ if enable_edit_guidance:
654
+ editing_pooled_prompt_embeds = editing_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
655
+ bs_embed_edit * num_images_per_prompt, -1
656
+ )
657
+
658
+ if self.text_encoder is not None:
659
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
660
+ # Retrieve the original scale by scaling back the LoRA layers
661
+ unscale_lora_layers(self.text_encoder, lora_scale)
662
+
663
+ if self.text_encoder_2 is not None:
664
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
665
+ # Retrieve the original scale by scaling back the LoRA layers
666
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
667
+
668
+ return (
669
+ negative_prompt_embeds,
670
+ edit_concepts_embeds,
671
+ negative_pooled_prompt_embeds,
672
+ editing_pooled_prompt_embeds,
673
+ num_edit_tokens,
674
+ )
675
+
676
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
677
+ def prepare_extra_step_kwargs(self, eta, generator=None):
678
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
679
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
680
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
681
+ # and should be between [0, 1]
682
+
683
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
684
+ extra_step_kwargs = {}
685
+ if accepts_eta:
686
+ extra_step_kwargs["eta"] = eta
687
+
688
+ # check if the scheduler accepts generator
689
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
690
+ if accepts_generator:
691
+ extra_step_kwargs["generator"] = generator
692
+ return extra_step_kwargs
693
+
694
+ def check_inputs(
695
+ self,
696
+ negative_prompt=None,
697
+ negative_prompt_2=None,
698
+ negative_prompt_embeds=None,
699
+ negative_pooled_prompt_embeds=None,
700
+ ):
701
+ if negative_prompt is not None and negative_prompt_embeds is not None:
702
+ raise ValueError(
703
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
704
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
705
+ )
706
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
707
+ raise ValueError(
708
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
709
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
710
+ )
711
+
712
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
713
+ raise ValueError(
714
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
715
+ )
716
+
717
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
718
+ def prepare_latents(self, device, latents):
719
+ latents = latents.to(device)
720
+
721
+ # scale the initial noise by the standard deviation required by the scheduler
722
+ latents = latents * self.scheduler.init_noise_sigma
723
+ return latents
724
+
725
+ def _get_add_time_ids(
726
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
727
+ ):
728
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
729
+
730
+ passed_add_embed_dim = (
731
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
732
+ )
733
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
734
+
735
+ if expected_add_embed_dim != passed_add_embed_dim:
736
+ raise ValueError(
737
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
738
+ )
739
+
740
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
741
+ return add_time_ids
742
+
743
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
744
+ def upcast_vae(self):
745
+ dtype = self.vae.dtype
746
+ self.vae.to(dtype=torch.float32)
747
+ use_torch_2_0_or_xformers = isinstance(
748
+ self.vae.decoder.mid_block.attentions[0].processor,
749
+ (
750
+ AttnProcessor2_0,
751
+ XFormersAttnProcessor,
752
+ ),
753
+ )
754
+ # if xformers or torch_2_0 is used attention block does not need
755
+ # to be in float32 which can save lots of memory
756
+ if use_torch_2_0_or_xformers:
757
+ self.vae.post_quant_conv.to(dtype)
758
+ self.vae.decoder.conv_in.to(dtype)
759
+ self.vae.decoder.mid_block.to(dtype)
760
+
761
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
762
+ def get_guidance_scale_embedding(
763
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
764
+ ) -> torch.Tensor:
765
+ """
766
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
767
+
768
+ Args:
769
+ w (`torch.Tensor`):
770
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
771
+ embedding_dim (`int`, *optional*, defaults to 512):
772
+ Dimension of the embeddings to generate.
773
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
774
+ Data type of the generated embeddings.
775
+
776
+ Returns:
777
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
778
+ """
779
+ assert len(w.shape) == 1
780
+ w = w * 1000.0
781
+
782
+ half_dim = embedding_dim // 2
783
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
784
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
785
+ emb = w.to(dtype)[:, None] * emb[None, :]
786
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
787
+ if embedding_dim % 2 == 1: # zero pad
788
+ emb = torch.nn.functional.pad(emb, (0, 1))
789
+ assert emb.shape == (w.shape[0], embedding_dim)
790
+ return emb
791
+
792
+ @property
793
+ def guidance_scale(self):
794
+ return self._guidance_scale
795
+
796
+ @property
797
+ def guidance_rescale(self):
798
+ return self._guidance_rescale
799
+
800
+ @property
801
+ def clip_skip(self):
802
+ return self._clip_skip
803
+
804
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
805
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
806
+ # corresponds to doing no classifier free guidance.
807
+ @property
808
+ def do_classifier_free_guidance(self):
809
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
810
+
811
+ @property
812
+ def cross_attention_kwargs(self):
813
+ return self._cross_attention_kwargs
814
+
815
+ @property
816
+ def denoising_end(self):
817
+ return self._denoising_end
818
+
819
+ @property
820
+ def num_timesteps(self):
821
+ return self._num_timesteps
822
+
823
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
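+ # Registers LEDITSCrossAttnProcessor on the cross-attention ("attn2") layers of the up and down blocks so their attention maps are recorded; mid-block and self-attention layers keep the default AttnProcessor.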
824
+ def prepare_unet(self, attention_store, PnP: bool = False):
825
+ attn_procs = {}
826
+ for name in self.unet.attn_processors.keys():
827
+ if name.startswith("mid_block"):
828
+ place_in_unet = "mid"
829
+ elif name.startswith("up_blocks"):
830
+ place_in_unet = "up"
831
+ elif name.startswith("down_blocks"):
832
+ place_in_unet = "down"
833
+ else:
834
+ continue
835
+
836
+ if "attn2" in name and place_in_unet != "mid":
837
+ attn_procs[name] = LEDITSCrossAttnProcessor(
838
+ attention_store=attention_store,
839
+ place_in_unet=place_in_unet,
840
+ pnp=PnP,
841
+ editing_prompts=self.enabled_editing_prompts,
842
+ )
843
+ else:
844
+ attn_procs[name] = AttnProcessor()
845
+
846
+ self.unet.set_attn_processor(attn_procs)
847
+
848
+ @spaces.GPU
849
+ @torch.no_grad()
850
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
851
+ def __call__(
852
+ self,
853
+ denoising_end: Optional[float] = None,
854
+ negative_prompt: Optional[Union[str, List[str]]] = None,
855
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
856
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
857
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
858
+ ip_adapter_image: Optional[PipelineImageInput] = None,
859
+ output_type: Optional[str] = "pil",
860
+ return_dict: bool = True,
861
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
862
+ guidance_rescale: float = 0.0,
863
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
864
+ target_size: Optional[Tuple[int, int]] = None,
865
+ editing_prompt: Optional[Union[str, List[str]]] = None,
866
+ editing_prompt_embeddings: Optional[torch.Tensor] = None,
867
+ editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
868
+ reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
869
+ edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
870
+ edit_warmup_steps: Optional[Union[int, List[int]]] = 0,
871
+ edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
872
+ edit_threshold: Optional[Union[float, List[float]]] = 0.9,
873
+ sem_guidance: Optional[List[torch.Tensor]] = None,
874
+ use_cross_attn_mask: bool = False,
875
+ use_intersect_mask: bool = False,
876
+ user_mask: Optional[torch.Tensor] = None,
877
+ attn_store_steps: Optional[List[int]] = [],
878
+ store_averaged_over_steps: bool = True,
879
+ clip_skip: Optional[int] = None,
880
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
881
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
882
+ avg_diff = None,
883
+ avg_diff_2 = None,
884
+ correlation_weight_factor = 0.7,
885
+ scale=2,
886
+ init_latents: Optional[torch.Tensor] = None,
887
+ zs: Optional[torch.Tensor] = None,
888
+ **kwargs,
889
+ ):
890
+ r"""
891
+ The call function to the pipeline for editing. The
892
+ [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL.invert`] method has to be called beforehand. Edits
893
+ will always be performed for the last inverted image(s).
894
+
895
+ Args:
896
+ denoising_end (`float`, *optional*):
897
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
898
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
899
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
900
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
901
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
902
+ negative_prompt (`str` or `List[str]`, *optional*):
903
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
904
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
905
+ less than `1`).
906
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
907
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
908
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
909
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
910
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
911
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
912
+ argument.
913
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
914
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
915
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
916
+ input argument.
917
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
918
+ Optional image input to work with IP Adapters.
919
+ output_type (`str`, *optional*, defaults to `"pil"`):
920
+ The output format of the generated image. Choose between
921
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
922
+ return_dict (`bool`, *optional*, defaults to `True`):
923
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
924
+ of a plain tuple.
925
+ callback (`Callable`, *optional*):
926
+ A function that will be called every `callback_steps` steps during inference. The function will be
927
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
928
+ callback_steps (`int`, *optional*, defaults to 1):
929
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
930
+ called at every step.
931
+ cross_attention_kwargs (`dict`, *optional*):
932
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
933
+ `self.processor` in
934
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
935
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
936
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
937
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
938
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
939
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
940
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
941
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
942
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
943
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
944
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
945
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
946
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
947
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
948
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
949
+ editing_prompt (`str` or `List[str]`, *optional*):
950
+ The prompt or prompts to guide the image generation. The image is reconstructed by setting
951
+ `editing_prompt = None`. Guidance direction of prompt should be specified via
952
+ `reverse_editing_direction`.
953
+ editing_prompt_embeddings (`torch.Tensor`, *optional*):
954
+ Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
955
+ If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input argument.
956
+ editing_pooled_prompt_embeddings (`torch.Tensor`, *optional*):
957
+ Pre-generated pooled edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
958
+ weighting. If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input
959
+ argument.
960
+ reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
961
+ Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
962
+ edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
963
+ Guidance scale for guiding the image generation. If provided as list values should correspond to
964
+ `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++
965
+ Paper](https://arxiv.org/abs/2301.12247).
966
+ edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10):
967
+ Number of diffusion steps (for each prompt) for which guidance is not applied.
968
+ edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`):
969
+ Number of diffusion steps (for each prompt) after which guidance is no longer applied.
970
+ edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
971
+ Masking threshold of guidance. Threshold should be proportional to the image region that is modified.
972
+ 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++
973
+ Paper](https://arxiv.org/abs/2301.12247).
974
+ sem_guidance (`List[torch.Tensor]`, *optional*):
975
+ List of pre-generated guidance vectors to be applied at generation. Length of the list has to
976
+ correspond to `num_inference_steps`.
977
+ use_cross_attn_mask:
978
+ Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask
979
+ is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++
980
+ paper](https://arxiv.org/pdf/2311.16711.pdf).
981
+ use_intersect_mask:
982
+ Whether the masking term is calculated as intersection of cross-attention masks and masks derived from
983
+ the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate
984
+ are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
985
+ user_mask:
986
+ User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s
987
+ implicit masks do not meet user preferences.
988
+ attn_store_steps:
989
+ Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes.
990
+ store_averaged_over_steps:
991
+ Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps. If
992
+ False, attention maps for each step are stored separately. Just for visualization purposes.
993
+ clip_skip (`int`, *optional*):
994
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
995
+ the output of the pre-final layer will be used for computing the prompt embeddings.
996
+ callback_on_step_end (`Callable`, *optional*):
997
+ A function that is called at the end of each denoising step during inference. The function is called
998
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
999
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1000
+ `callback_on_step_end_tensor_inputs`.
1001
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1002
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1003
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1004
+ `._callback_tensor_inputs` attribute of your pipeline class.
1005
+
1006
+ Examples:
1007
+
1008
+ Returns:
1009
+ [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`:
1010
+ [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
1011
+ returning a tuple, the first element is a list with the generated images.
1012
+ """
1013
+ if self.inversion_steps is None:
1014
+ raise ValueError(
1015
+ "You need to invert an input image first before calling the pipeline. The `invert` method has to be called beforehand. Edits will always be performed for the last inverted image(s)."
1016
+ )
1017
+
1018
+ eta = self.eta
1019
+ num_images_per_prompt = 1
1020
+ #latents = self.init_latents
1021
+ latents = init_latents
1022
+
1023
+ #zs = self.zs
1024
+ self.scheduler.set_timesteps(len(self.scheduler.timesteps))
1025
+
1026
+ if use_intersect_mask:
1027
+ use_cross_attn_mask = True
1028
+
1029
+ if use_cross_attn_mask:
1030
+ self.smoothing = LeditsGaussianSmoothing(self.device)
1031
+
1032
+ if user_mask is not None:
1033
+ user_mask = user_mask.to(self.device)
1034
+
1035
+ # TODO: Check inputs
1036
+ # 1. Check inputs. Raise error if not correct
1037
+ # self.check_inputs(
1038
+ # callback_steps,
1039
+ # negative_prompt,
1040
+ # negative_prompt_2,
1041
+ # prompt_embeds,
1042
+ # negative_prompt_embeds,
1043
+ # pooled_prompt_embeds,
1044
+ # negative_pooled_prompt_embeds,
1045
+ # )
1046
+ self._guidance_rescale = guidance_rescale
1047
+ self._clip_skip = clip_skip
1048
+ self._cross_attention_kwargs = cross_attention_kwargs
1049
+ self._denoising_end = denoising_end
1050
+
1051
+ # 2. Define call parameters
1052
+ batch_size = self.batch_size
1053
+
1054
+ device = self._execution_device
1055
+
1056
+ if editing_prompt:
1057
+ enable_edit_guidance = True
1058
+ if isinstance(editing_prompt, str):
1059
+ editing_prompt = [editing_prompt]
1060
+ self.enabled_editing_prompts = len(editing_prompt)
1061
+ elif editing_prompt_embeddings is not None:
1062
+ enable_edit_guidance = True
1063
+ self.enabled_editing_prompts = editing_prompt_embeddings.shape[0]
1064
+ else:
1065
+ self.enabled_editing_prompts = 0
1066
+ enable_edit_guidance = False
1068
+ # 3. Encode input prompt
1069
+ text_encoder_lora_scale = (
1070
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1071
+ )
1072
+ (
1073
+ prompt_embeds,
1074
+ edit_prompt_embeds,
1075
+ negative_pooled_prompt_embeds,
1076
+ pooled_edit_embeds,
1077
+ num_edit_tokens,
1078
+ ) = self.encode_prompt(
1079
+ device=device,
1080
+ num_images_per_prompt=num_images_per_prompt,
1081
+ negative_prompt=negative_prompt,
1082
+ negative_prompt_2=negative_prompt_2,
1083
+ negative_prompt_embeds=negative_prompt_embeds,
1084
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1085
+ lora_scale=text_encoder_lora_scale,
1086
+ clip_skip=self.clip_skip,
1087
+ enable_edit_guidance=enable_edit_guidance,
1088
+ editing_prompt=editing_prompt,
1089
+ editing_prompt_embeds=editing_prompt_embeddings,
1090
+ editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
1091
+ avg_diff = avg_diff,
1092
+ avg_diff_2 = avg_diff_2,
1093
+ correlation_weight_factor = correlation_weight_factor,
1094
+ scale=scale,
1095
+ )
1096
+
1097
+ # 4. Prepare timesteps
1098
+ # self.scheduler.set_timesteps(num_inference_steps, device=device)
1099
+
1100
+ timesteps = self.inversion_steps
1102
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
1103
+
1104
+ if use_cross_attn_mask:
1105
+ self.attention_store = LeditsAttentionStore(
1106
+ average=store_averaged_over_steps,
1107
+ batch_size=batch_size,
1108
+ max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0),
1109
+ max_resolution=None,
1110
+ )
1111
+ self.prepare_unet(self.attention_store)
1112
+ resolution = latents.shape[-2:]
1113
+ att_res = (int(resolution[0] / 4), int(resolution[1] / 4))
+
+        # 5. Prepare latent variables
+        latents = self.prepare_latents(device=device, latents=latents)
+
+        # 6. Prepare extra step kwargs.
+        extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
+
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        # 7. Prepare added time ids & embeddings
+        add_text_embeds = negative_pooled_prompt_embeds
+        add_time_ids = self._get_add_time_ids(
+            self.size,
+            crops_coords_top_left,
+            self.size,
+            dtype=negative_pooled_prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        if enable_edit_guidance:
+            prompt_embeds = torch.cat([prompt_embeds, edit_prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([add_text_embeds, pooled_edit_embeds], dim=0)
+            edit_concepts_time_ids = add_time_ids.repeat(edit_prompt_embeds.shape[0], 1)
+            add_time_ids = torch.cat([add_time_ids, edit_concepts_time_ids], dim=0)
+            self.text_cross_attention_maps = [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt
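+            # The batch is laid out as [uncond, edit_concept_1, ..., edit_concept_n], so a single UNet forward
+            # pass returns the unconditional prediction plus one prediction per editing concept.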
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        if ip_adapter_image is not None:
+            # TODO: fix image encoding
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+                image_embeds = image_embeds.to(device)
+
+        # 8. Denoising loop
+        self.sem_guidance = None
+        self.activation_mask = None
+
+        if (
+            self.denoising_end is not None
+            and isinstance(self.denoising_end, float)
+            and self.denoising_end > 0
+            and self.denoising_end < 1
+        ):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        # 9. Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        self._num_timesteps = len(timesteps)
+        with self.progress_bar(total=self._num_timesteps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts))
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts)  # [b,4, 64, 64]
+                noise_pred_uncond = noise_pred_out[0]
+                noise_pred_edit_concepts = noise_pred_out[1:]
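+                # Split the batched prediction back into the unconditional estimate and one estimate per editing
+                # concept; all semantic guidance below is built from their differences.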
+
+                noise_guidance_edit = torch.zeros(
+                    noise_pred_uncond.shape,
+                    device=self.device,
+                    dtype=noise_pred_uncond.dtype,
+                )
+
+                if sem_guidance is not None and len(sem_guidance) > i:
+                    noise_guidance_edit += sem_guidance[i].to(self.device)
+
+                elif enable_edit_guidance:
+                    if self.activation_mask is None:
+                        self.activation_mask = torch.zeros(
+                            (len(timesteps), self.enabled_editing_prompts, *noise_pred_edit_concepts[0].shape)
+                        )
+                    if self.sem_guidance is None:
+                        self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape))
+
+                    # noise_guidance_edit = torch.zeros_like(noise_guidance)
+                    for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
+                        if isinstance(edit_warmup_steps, list):
+                            edit_warmup_steps_c = edit_warmup_steps[c]
+                        else:
+                            edit_warmup_steps_c = edit_warmup_steps
+                        if i < edit_warmup_steps_c:
+                            continue
+
+                        if isinstance(edit_guidance_scale, list):
+                            edit_guidance_scale_c = edit_guidance_scale[c]
+                        else:
+                            edit_guidance_scale_c = edit_guidance_scale
+
+                        if isinstance(edit_threshold, list):
+                            edit_threshold_c = edit_threshold[c]
+                        else:
+                            edit_threshold_c = edit_threshold
+                        if isinstance(reverse_editing_direction, list):
+                            reverse_editing_direction_c = reverse_editing_direction[c]
+                        else:
+                            reverse_editing_direction_c = reverse_editing_direction
+
+                        if isinstance(edit_cooldown_steps, list):
+                            edit_cooldown_steps_c = edit_cooldown_steps[c]
+                        elif edit_cooldown_steps is None:
+                            edit_cooldown_steps_c = i + 1
+                        else:
+                            edit_cooldown_steps_c = edit_cooldown_steps
+
+                        if i >= edit_cooldown_steps_c:
+                            continue
+
+                        noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
+
+                        if reverse_editing_direction_c:
+                            noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
+
+                        noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
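+                        # Per-concept guidance term: the difference between the concept-conditioned and the
+                        # unconditional noise estimate, optionally sign-flipped and scaled by `edit_guidance_scale_c`.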
+
+                        if user_mask is not None:
+                            noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
+
+                        if use_cross_attn_mask:
+                            out = self.attention_store.aggregate_attention(
+                                attention_maps=self.attention_store.step_store,
+                                prompts=self.text_cross_attention_maps,
+                                res=att_res,
+                                from_where=["up", "down"],
+                                is_cross=True,
+                                select=self.text_cross_attention_maps.index(editing_prompt[c]),
+                            )
+                            attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]]  # 0 -> startoftext
+
+                            # average over all tokens
+                            if attn_map.shape[3] != num_edit_tokens[c]:
+                                raise ValueError(
+                                    f"Incorrect shape of attention_map. Expected size {num_edit_tokens[c]}, but found {attn_map.shape[3]}!"
+                                )
+                            attn_map = torch.sum(attn_map, dim=3)
+
+                            # gaussian_smoothing
+                            attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect")
+                            attn_map = self.smoothing(attn_map).squeeze(1)
+
+                            # torch.quantile function expects float32
+                            if attn_map.dtype == torch.float32:
+                                tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1)
+                            else:
+                                tmp = torch.quantile(
+                                    attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1
+                                ).to(attn_map.dtype)
+                            attn_mask = torch.where(
+                                attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1, *att_res), 1.0, 0.0
+                            )
+
+                            # resolution must match latent space dimension
+                            attn_mask = F.interpolate(
+                                attn_mask.unsqueeze(1),
+                                noise_guidance_edit_tmp.shape[-2:],  # 64,64
+                            ).repeat(1, 4, 1, 1)
+                            self.activation_mask[i, c] = attn_mask.detach().cpu()
+                            if not use_intersect_mask:
+                                noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
+
+                        if use_intersect_mask:
+                            noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
+                            noise_guidance_edit_tmp_quantile = torch.sum(
+                                noise_guidance_edit_tmp_quantile, dim=1, keepdim=True
+                            )
+                            noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(
+                                1, self.unet.config.in_channels, 1, 1
+                            )
+
+                            # torch.quantile function expects float32
+                            if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
+                                tmp = torch.quantile(
+                                    noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
+                                    edit_threshold_c,
+                                    dim=2,
+                                    keepdim=False,
+                                )
+                            else:
+                                tmp = torch.quantile(
+                                    noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
+                                    edit_threshold_c,
+                                    dim=2,
+                                    keepdim=False,
+                                ).to(noise_guidance_edit_tmp_quantile.dtype)
+
+                            intersect_mask = (
+                                torch.where(
+                                    noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
+                                    torch.ones_like(noise_guidance_edit_tmp),
+                                    torch.zeros_like(noise_guidance_edit_tmp),
+                                )
+                                * attn_mask
+                            )
+
+                            self.activation_mask[i, c] = intersect_mask.detach().cpu()
+
+                            noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask
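+                            # The intersect mask keeps only regions that lie inside the cross-attention mask for
+                            # this concept and above the `edit_threshold_c` quantile of the guidance magnitude,
+                            # which tends to localize the edit further.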
+
+                        elif not use_cross_attn_mask:
+                            # calculate quantile
+                            noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
+                            noise_guidance_edit_tmp_quantile = torch.sum(
+                                noise_guidance_edit_tmp_quantile, dim=1, keepdim=True
+                            )
+                            noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
+
+                            # torch.quantile function expects float32
+                            if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
+                                tmp = torch.quantile(
+                                    noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
+                                    edit_threshold_c,
+                                    dim=2,
+                                    keepdim=False,
+                                )
+                            else:
+                                tmp = torch.quantile(
+                                    noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
+                                    edit_threshold_c,
+                                    dim=2,
+                                    keepdim=False,
+                                ).to(noise_guidance_edit_tmp_quantile.dtype)
+
+                            self.activation_mask[i, c] = (
+                                torch.where(
+                                    noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
+                                    torch.ones_like(noise_guidance_edit_tmp),
+                                    torch.zeros_like(noise_guidance_edit_tmp),
+                                )
+                                .detach()
+                                .cpu()
+                            )
+
+                            noise_guidance_edit_tmp = torch.where(
+                                noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
+                                noise_guidance_edit_tmp,
+                                torch.zeros_like(noise_guidance_edit_tmp),
+                            )
+
+                        noise_guidance_edit += noise_guidance_edit_tmp
+
+                    self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
+
+                noise_pred = noise_pred_uncond + noise_guidance_edit
+
+                # compute the previous noisy sample x_t -> x_t-1
+                if enable_edit_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred,
+                        noise_pred_edit_concepts.mean(dim=0, keepdim=False),
+                        guidance_rescale=self.guidance_rescale,
+                    )
+
+                idx = t_to_idx[int(t)]
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, variance_noise=zs[idx], **extra_step_kwargs, return_dict=False
+                )[0]
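+                # The stored inversion noise `zs[idx]` is injected as the scheduler's variance noise, so with no
+                # active edit the trajectory reproduces the inverted input image; edits only perturb it where
+                # guidance is applied.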
+
+                # step callback
+                if use_cross_attn_mask:
+                    store_step = i in attn_store_steps
+                    self.attention_store.between_steps(store_step)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    # negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            image = latents
+
+        if not output_type == "latent":
+            # apply watermark if available
+            if self.watermark is not None:
+                image = self.watermark.apply_watermark(image)
+
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return LEditsPPDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
+
+    @torch.no_grad()
+    # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image
+    def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None):
+        image = self.image_processor.preprocess(
+            image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
+        resized = self.image_processor.postprocess(image=image, output_type="pil")
+
+        if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
+            logger.warning(
+                "Your input images far exceed the default resolution of the underlying diffusion model. "
+                "The output images may contain severe artifacts! "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
+        image = image.to(self.device, dtype=dtype)
+        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+        if needs_upcasting:
+            image = image.float()
+            self.upcast_vae()
+
+        x0 = self.vae.encode(image).latent_dist.mode()
+        x0 = x0.to(dtype)
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.vae.to(dtype=torch.float16)
+
+        x0 = self.vae.config.scaling_factor * x0
+        return x0, resized
+
+    @spaces.GPU
+    @torch.no_grad()
+    def invert(
+        self,
+        image: PipelineImageInput,
+        source_prompt: str = "",
+        source_guidance_scale=3.5,
+        negative_prompt: str = None,
+        negative_prompt_2: str = None,
+        num_inversion_steps: int = 50,
+        skip: float = 0.15,
+        generator: Optional[torch.Generator] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        num_zero_noise_steps: int = 3,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        r"""
+        The function of the pipeline for image inversion, as described in the [LEDITS++
+        Paper](https://arxiv.org/abs/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`], the
+        inversion proposed by [edit-friendly DDPM](https://arxiv.org/abs/2304.06140) is performed instead.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input for the image(s) that are to be edited. Multiple input images must share the same aspect
+                ratio.
+            source_prompt (`str`, defaults to `""`):
+                Prompt describing the input image that will be used for guidance during inversion. Guidance is
+                disabled if the `source_prompt` is `""`.
+            source_guidance_scale (`float`, defaults to `3.5`):
+                Strength of guidance during inversion.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text encoders.
+            num_inversion_steps (`int`, defaults to `50`):
+                Total number of inversion steps performed after discarding the initial `skip` portion.
+            skip (`float`, defaults to `0.15`):
+                Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values
+                will lead to stronger changes to the input image. `skip` has to be between `0` and `1`.
+            generator (`torch.Generator`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                inversion deterministic.
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the
+                position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by
+                setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            num_zero_noise_steps (`int`, defaults to `3`):
+                Number of final diffusion steps that will not renoise the current image. If no steps are set to
+                zero, SD-XL in combination with [`DPMSolverMultistepScheduler`] will produce noise artifacts.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`: the inverted starting latents (`xts[-1]`) and the per-step noise
+            maps `zs`, which are meant to be passed back to the pipeline call for editing. (The original
+            [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`] return, containing the resized input image(s)
+            and respective VAE reconstruction(s), is commented out below.)
+        """
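+        # Minimal usage sketch (hypothetical names; assumes `pipe` is an instance of this pipeline, `init_image`
+        # is a PIL image, and that `latents`, `zs` and `inversion_steps` are the matching `__call__` arguments
+        # defined in this file):
+        #
+        #   init_latents, zs = pipe.invert(init_image, source_prompt="", num_inversion_steps=50, skip=0.3)
+        #   edited = pipe(editing_prompt=["a smiling face"], latents=init_latents, zs=zs,
+        #                 inversion_steps=pipe.inversion_steps)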
+
+        # Reset attn processor, we do not want to store attn maps during inversion
+        self.unet.set_attn_processor(AttnProcessor())
+
+        self.eta = 1.0
+
+        self.scheduler.config.timestep_spacing = "leading"
+        self.scheduler.set_timesteps(int(num_inversion_steps * (1 + skip)))
+        self.inversion_steps = self.scheduler.timesteps[-num_inversion_steps:]
+        timesteps = self.inversion_steps
+
+        num_images_per_prompt = 1
+
+        device = self._execution_device
+
+        # 0. Ensure that only uncond embedding is used if prompt = ""
+        if source_prompt == "":
+            # noise pred should only be noise_pred_uncond
+            source_guidance_scale = 0.0
+            do_classifier_free_guidance = False
+        else:
+            do_classifier_free_guidance = source_guidance_scale > 1.0
+
+        # 1. prepare image
+        x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+        width = x0.shape[2] * self.vae_scale_factor
+        height = x0.shape[3] * self.vae_scale_factor
+        self.size = (height, width)
+
+        self.batch_size = x0.shape[0]
+
+        # 2. get embeddings
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
+
+        if isinstance(source_prompt, str):
+            source_prompt = [source_prompt] * self.batch_size
+
+        (
+            negative_prompt_embeds,
+            prompt_embeds,
+            negative_pooled_prompt_embeds,
+            edit_pooled_prompt_embeds,
+            _,
+        ) = self.encode_prompt(
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            editing_prompt=source_prompt,
+            lora_scale=text_encoder_lora_scale,
+            enable_edit_guidance=do_classifier_free_guidance,
+        )
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        # 3. Prepare added time ids & embeddings
+        add_text_embeds = negative_pooled_prompt_embeds
+        add_time_ids = self._get_add_time_ids(
+            self.size,
+            crops_coords_top_left,
+            self.size,
+            dtype=negative_prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        if do_classifier_free_guidance:
+            negative_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([add_text_embeds, edit_pooled_prompt_embeds], dim=0)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        negative_prompt_embeds = negative_prompt_embeds.to(device)
+
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(self.batch_size * num_images_per_prompt, 1)
+
+        # autoencoder reconstruction
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+            image_rec = self.vae.decode(
+                x0_tmp / self.vae.config.scaling_factor, return_dict=False, generator=generator
+            )[0]
+        elif self.vae.config.force_upcast:
+            x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+            image_rec = self.vae.decode(
+                x0_tmp / self.vae.config.scaling_factor, return_dict=False, generator=generator
+            )[0]
+        else:
+            image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
+
+        image_rec = self.image_processor.postprocess(image_rec, output_type="pil")
+
+        # 5. find zs and xts
+        variance_noise_shape = (num_inversion_steps, *x0.shape)
+
+        # intermediate latents
+        t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+        xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype)
+
+        for t in reversed(timesteps):
+            idx = num_inversion_steps - t_to_idx[int(t)] - 1
+            noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype)
+            xts[idx] = self.scheduler.add_noise(x0, noise, t.unsqueeze(0))
+        xts = torch.cat([x0.unsqueeze(0), xts], dim=0)
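+        # Each intermediate latent x_t is sampled by independently renoising x_0 at timestep t rather than by
+        # running a single forward-diffusion chain, following the edit-friendly DDPM inversion referenced in the
+        # docstring above.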
+
+        # noise maps
+        zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype)
+
+        self.scheduler.set_timesteps(len(self.scheduler.timesteps))
+
+        for t in self.progress_bar(timesteps):
+            idx = num_inversion_steps - t_to_idx[int(t)] - 1
+            # 1. predict noise residual
+            xt = xts[idx + 1]
+
+            latent_model_input = torch.cat([xt] * 2) if do_classifier_free_guidance else xt
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+            noise_pred = self.unet(
+                latent_model_input,
+                t,
+                encoder_hidden_states=negative_prompt_embeds,
+                cross_attention_kwargs=cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
+                return_dict=False,
+            )[0]
+
+            # 2. perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_out = noise_pred.chunk(2)
+                noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
+                noise_pred = noise_pred_uncond + source_guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            xtm1 = xts[idx]
+            z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, self.eta)
+            zs[idx] = z
+
+            # correction to avoid error accumulation
+            xts[idx] = xtm1_corrected
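+            # `compute_noise` solves the scheduler's update rule for the noise z that maps x_t to the observed
+            # x_{t-1}; replaying these noise maps at generation time (via `variance_noise=zs[idx]`) reconstructs
+            # the input image.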
+        self.init_latents = xts[-1]
+        zs = zs.flip(0)
+
+        if num_zero_noise_steps > 0:
+            zs[-num_zero_noise_steps:] = torch.zeros_like(zs[-num_zero_noise_steps:])
+        self.zs = zs
+        # return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec)
+        return xts[-1], zs
+
+
+
+# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_ddim
+def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta):
+    # 1. get previous step value (=t-1)
+    prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
+
+    # 2. compute alphas, betas
+    alpha_prod_t = scheduler.alphas_cumprod[timestep]
+    alpha_prod_t_prev = (
+        scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
+    )
+
+    beta_prod_t = 1 - alpha_prod_t
+
+    # 3. compute predicted original sample from predicted noise also called
+    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+
+    # 4. Clip "predicted x_0"
+    if scheduler.config.clip_sample:
+        pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+    # 5. compute variance: "sigma_t(η)" -> see formula (16)
+    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+    variance = scheduler._get_variance(timestep, prev_timestep)
+    std_dev_t = eta * variance ** (0.5)
+
+    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
+
+    # modified so that the updated xtm1 is returned as well (to avoid error accumulation)
+    mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
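+    # Invert the update x_{t-1} = mu_t + eta * sigma_t * z for z, i.e. recover the noise that explains the
+    # observed `prev_latents` given `latents`.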
+    if variance > 0.0:
+        noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta)
+    else:
+        noise = torch.tensor([0.0]).to(latents.device)
+
+    return noise, mu_xt + (eta * variance**0.5) * noise
+
+
+# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_sde_dpm_pp_2nd
+def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta):
+    def first_order_update(model_output, sample):  # timestep, prev_timestep, sample):
+        sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index]
+        alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t)
+        alpha_s, sigma_s = scheduler._sigma_to_alpha_sigma_t(sigma_s)
+        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+        lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
+
+        h = lambda_t - lambda_s
+
+        mu_xt = (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+
+        mu_xt = scheduler.dpm_solver_first_order_update(
+            model_output=model_output, sample=sample, noise=torch.zeros_like(sample)
+        )
+
+        sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h))
+        if sigma > 0.0:
+            noise = (prev_latents - mu_xt) / sigma
+        else:
+            noise = torch.tensor([0.0]).to(sample.device)
+
+        prev_sample = mu_xt + sigma * noise
+        return noise, prev_sample
+
+    def second_order_update(model_output_list, sample):  # timestep_list, prev_timestep, sample):
+        sigma_t, sigma_s0, sigma_s1 = (
+            scheduler.sigmas[scheduler.step_index + 1],
+            scheduler.sigmas[scheduler.step_index],
+            scheduler.sigmas[scheduler.step_index - 1],
+        )
+
+        alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t)
+        alpha_s0, sigma_s0 = scheduler._sigma_to_alpha_sigma_t(sigma_s0)
+        alpha_s1, sigma_s1 = scheduler._sigma_to_alpha_sigma_t(sigma_s1)
+
+        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+        lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
+
+        m0, m1 = model_output_list[-1], model_output_list[-2]
+
+        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+        r0 = h_0 / h
+        D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+
+        mu_xt = (
+            (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+            + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+            + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+        )
+
+        sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h))
+        if sigma > 0.0:
+            noise = (prev_latents - mu_xt) / sigma
+        else:
+            noise = torch.tensor([0.0]).to(sample.device)
+
+        prev_sample = mu_xt + sigma * noise
+
+        return noise, prev_sample
+
+    if scheduler.step_index is None:
+        scheduler._init_step_index(timestep)
+
+    model_output = scheduler.convert_model_output(model_output=noise_pred, sample=latents)
+    for i in range(scheduler.config.solver_order - 1):
+        scheduler.model_outputs[i] = scheduler.model_outputs[i + 1]
+    scheduler.model_outputs[-1] = model_output
+
+    if scheduler.lower_order_nums < 1:
+        noise, prev_sample = first_order_update(model_output, latents)
+    else:
+        noise, prev_sample = second_order_update(scheduler.model_outputs, latents)
+
+    if scheduler.lower_order_nums < scheduler.config.solver_order:
+        scheduler.lower_order_nums += 1
+
+    # upon completion increase step index by one
+    scheduler._step_index += 1
+
+    return noise, prev_sample
+
+
+# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise
+def compute_noise(scheduler, *args):
+    if isinstance(scheduler, DDIMScheduler):
+        return compute_noise_ddim(scheduler, *args)
+    elif (
+        isinstance(scheduler, DPMSolverMultistepScheduler)
+        and scheduler.config.algorithm_type == "sde-dpmsolver++"
+        and scheduler.config.solver_order == 2
+    ):
+        return compute_noise_sde_dpm_pp_2nd(scheduler, *args)
+    else:
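+        # Only DDIMScheduler and second-order "sde-dpmsolver++" DPMSolverMultistepScheduler configurations are
+        # supported for this inversion; anything else falls through to the error below.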
+        raise NotImplementedError
+
+
 import gradio as gr
 import spaces
 import torch

 import numpy as np
 import cv2
 from PIL import Image
+#from ledits.pipeline_leditspp_stable_diffusion_xl import LEditsPPPipelineStableDiffusionXL

 def HWC3(x):
     assert x.dtype == np.uint8

     return image

 @spaces.GPU
+def invert_image(image, num_inversion_steps=50, skip=0.3):
     image = image.resize((512,512))
     init_latents,zs = clip_slider_inv.pipe.invert(
         source_prompt = "",

                  inputs=[slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2],
                  outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image])

+image_inv.change(fn=reset_do_inversion, outputs=[do_inversion]).then(fn=invert_image, inputs=[image_inv], outputs=[init_latents,zs])
 submit_inv.click(fn=generate,
                  inputs=[slider_x_inv, slider_y_inv, prompt_inv, seed_inv, iterations_inv, steps_inv, guidance_scale_inv, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, img2img_type_inv, image, controlnet_conditioning_scale, ip_adapter_scale ,edit_threshold, edit_guidance_scale, init_latents, zs],
                  outputs=[x_inv, y_inv, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image_inv])