ccchenzc committed on
Commit
f2f17f4
·
1 Parent(s): 4a985f1

Init demo.

app.py CHANGED
@@ -1,7 +1,39 @@
- import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import gradio as gr
+ from webui import (
+     create_interface_texture_synthesis,
+     create_interface_style_t2i,
+     create_interface_style_transfer,
+     Runner
+ )
+
+
+ import os
+ os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
+
+
+ def main():
+     runner = Runner()
+
+     with gr.Blocks(analytics_enabled=False,
+                    title='Attention Distillation',
+                    ) as demo:
+
+         md_txt = "# Attention Distillation" \
+                  "\nOfficial demo of the paper [Attention Distillation: A Unified Approach to Visual Characteristics Transfer](https://arxiv.org/abs/2502.20235)"
+         gr.Markdown(md_txt)
+
+         with gr.Tabs(selected='tab_style_transfer'):
+             with gr.TabItem("Style Transfer", id='tab_style_transfer'):
+                 create_interface_style_transfer(runner=runner)
+
+             with gr.TabItem("Style-Specific Text-to-Image Generation", id='tab_style_t2i'):
+                 create_interface_style_t2i(runner=runner)
+
+             with gr.TabItem("Texture Synthesis", id='tab_texture_syn'):
+                 create_interface_texture_synthesis(runner=runner)
+
+     demo.launch(share=False, debug=False)
+
+
+ if __name__ == '__main__':
+     main()
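For reference: with the packages from requirements.txt installed and the model weights placed under ./checkpoints (the layout webui/runner.py expects), `python app.py` starts the demo locally and serves the three tabs from a single gr.Blocks app.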
losses.py ADDED
@@ -0,0 +1,56 @@
+ import math
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ loss_fn = torch.nn.L1Loss()
+
+
+ def ad_loss(
+     q_list, ks_list, vs_list, self_out_list, scale=1, source_mask=None, target_mask=None
+ ):
+     loss = 0
+     attn_mask = None
+     for q, ks, vs, self_out in zip(q_list, ks_list, vs_list, self_out_list):
+         if source_mask is not None and target_mask is not None:
+             w = h = int(np.sqrt(q.shape[2]))
+             mask_1 = torch.flatten(F.interpolate(source_mask, size=(h, w)))
+             mask_2 = torch.flatten(F.interpolate(target_mask, size=(h, w)))
+             attn_mask = mask_1.unsqueeze(0) == mask_2.unsqueeze(1)
+             attn_mask = attn_mask.to(q.device)
+
+         target_out = F.scaled_dot_product_attention(
+             q * scale,
+             torch.cat(torch.chunk(ks, ks.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
+             torch.cat(torch.chunk(vs, vs.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
+             attn_mask=attn_mask
+         )
+         loss += loss_fn(self_out, target_out.detach())
+     return loss
+
+
+ def q_loss(q_list, qc_list):
+     loss = 0
+     for q, qc in zip(q_list, qc_list):
+         loss += loss_fn(q, qc.detach())
+     return loss
+
+
+ # weight = 200
+ def qk_loss(q_list, k_list, qc_list, kc_list):
+     loss = 0
+     for q, k, qc, kc in zip(q_list, k_list, qc_list, kc_list):
+         scale_factor = 1 / math.sqrt(q.size(-1))
+         self_map = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)
+         target_map = torch.softmax(qc @ kc.transpose(-2, -1) * scale_factor, dim=-1)
+         loss += loss_fn(self_map, target_map.detach())
+     return loss
+
+
+ # weight = 1
+ def qkv_loss(q_list, k_list, vc_list, c_out_list):
+     loss = 0
+     for q, k, vc, target_out in zip(q_list, k_list, vc_list, c_out_list):
+         self_out = F.scaled_dot_product_attention(q, k, vc)
+         loss += loss_fn(self_out, target_out.detach())
+     return loss
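A minimal sketch of how these losses are consumed (illustrative only: the tensor shapes follow what utils.DataCache stores per hooked self-attention layer, i.e. (batch, heads, tokens, head_dim); the sizes below are made up):

import torch
from losses import ad_loss

q   = torch.randn(1, 8, 4096, 40, requires_grad=True)  # queries of the image being optimized
ks  = torch.randn(1, 8, 4096, 40)                       # keys cached from the style image
vs  = torch.randn(1, 8, 4096, 40)                       # values cached from the style image
out = torch.randn(1, 8, 4096, 40, requires_grad=True)   # self-attention output of the optimized image

# target = attention of q over the style keys/values; the L1 gap to `out` is the style loss
loss = ad_loss([q], [ks], [vs], [out])
loss.backward()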
pipeline_sd.py ADDED
@@ -0,0 +1,680 @@
+ import copy
+ import math
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ import utils
+ from accelerate import Accelerator
+ from diffusers import StableDiffusionPipeline
+ from diffusers.image_processor import PipelineImageInput
+ from losses import *
+ from tqdm import tqdm
+
+
+ class ADPipeline(StableDiffusionPipeline):
+     def freeze(self):
+         self.vae.requires_grad_(False)
+         self.unet.requires_grad_(False)
+         self.text_encoder.requires_grad_(False)
+         self.classifier.requires_grad_(False)
+
+     @torch.no_grad()
+     def image2latent(self, image):
+         dtype = next(self.vae.parameters()).dtype
+         device = self._execution_device
+         image = image.to(device=device, dtype=dtype) * 2.0 - 1.0
+         latent = self.vae.encode(image)["latent_dist"].mean
+         latent = latent * self.vae.config.scaling_factor
+         return latent
+
+     @torch.no_grad()
+     def latent2image(self, latent):
+         dtype = next(self.vae.parameters()).dtype
+         device = self._execution_device
+         latent = latent.to(device=device, dtype=dtype)
+         latent = latent / self.vae.config.scaling_factor
+         image = self.vae.decode(latent)[0]
+         return (image * 0.5 + 0.5).clamp(0, 1)
+
+     def init(self, enable_gradient_checkpoint):
+         self.freeze()
+         weight_dtype = torch.float32
+         if self.accelerator.mixed_precision == "fp16":
+             weight_dtype = torch.float16
+         elif self.accelerator.mixed_precision == "bf16":
+             weight_dtype = torch.bfloat16
+
+         # Move unet, vae and text_encoder to device and cast to weight_dtype
+         self.unet.to(self.accelerator.device, dtype=weight_dtype)
+         self.vae.to(self.accelerator.device, dtype=weight_dtype)
+         self.text_encoder.to(self.accelerator.device, dtype=weight_dtype)
+         self.classifier.to(self.accelerator.device, dtype=weight_dtype)
+         self.classifier = self.accelerator.prepare(self.classifier)
+         if enable_gradient_checkpoint:
+             self.classifier.enable_gradient_checkpointing()
+
+     def sample(
+         self,
+         lr=0.05,
+         iters=1,
+         attn_scale=1,
+         adain=False,
+         weight=0.25,
+         controller=None,
+         style_image=None,
+         content_image=None,
+         mixed_precision="no",
+         start_time=999,
+         enable_gradient_checkpoint=False,
+         prompt: Union[str, List[str]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         guidance_rescale: float = 0.0,
+         clip_skip: Optional[int] = None,
+         **kwargs,
+     ):
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+         self._guidance_scale = guidance_scale
+         self._guidance_rescale = guidance_rescale
+         self._clip_skip = clip_skip
+         self._cross_attention_kwargs = cross_attention_kwargs
+         self._interrupt = False
+
+         self.accelerator = Accelerator(
+             mixed_precision=mixed_precision, gradient_accumulation_steps=1
+         )
+         self.init(enable_gradient_checkpoint)
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+
+         # 3. Encode input prompt
+         lora_scale = (
+             self.cross_attention_kwargs.get("scale", None)
+             if self.cross_attention_kwargs is not None
+             else None
+         )
+         do_cfg = guidance_scale > 1.0
+
+         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_cfg,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             lora_scale=lora_scale,
+             clip_skip=self.clip_skip,
+         )
+
+         # For classifier free guidance, we need to do two forward passes.
+         # Here we concatenate the unconditional and text embeddings into a single batch
+         # to avoid doing two forward passes
+         if do_cfg:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+             image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image,
+                 ip_adapter_image_embeds,
+                 device,
+                 batch_size * num_images_per_prompt,
+                 do_cfg,
+             )
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6.1 Add image embeds for IP-Adapter
+         added_cond_kwargs = (
+             {"image_embeds": image_embeds}
+             if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
+             else None
+         )
+
+         # 6.2 Optionally get Guidance Scale Embedding
+         timestep_cond = None
+         if self.unet.config.time_cond_proj_dim is not None:
+             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
+                 batch_size * num_images_per_prompt
+             )
+             timestep_cond = self.get_guidance_scale_embedding(
+                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+             ).to(device=device, dtype=latents.dtype)
+
+         self.scheduler.set_timesteps(num_inference_steps)
+         timesteps = self.scheduler.timesteps
+         self.style_latent = self.image2latent(style_image)
+         if content_image is not None:
+             self.content_latent = self.image2latent(content_image)
+         else:
+             self.content_latent = None
+         null_embeds = self.encode_prompt("", device, 1, False)[0]
+         self.null_embeds = null_embeds
+         self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0])
+         self.null_embeds_for_style = torch.cat(
+             [null_embeds] * self.style_latent.shape[0]
+         )
+
+         self.adain = adain
+         self.attn_scale = attn_scale
+         self.cache = utils.DataCache()
+         self.controller = controller
+         utils.register_attn_control(
+             self.classifier, controller=self.controller, cache=self.cache
+         )
+         print("Total self attention layers of Unet: ", controller.num_self_layers)
+         print("Self attention layers for AD: ", controller.self_layers)
+
+         pbar = tqdm(timesteps, desc="Sample")
+         for i, t in enumerate(pbar):
+             with torch.no_grad():
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
+                 latent_model_input = self.scheduler.scale_model_input(
+                     latent_model_input, t
+                 )
+                 # predict the noise residual
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     timestep_cond=timestep_cond,
+                     cross_attention_kwargs=self.cross_attention_kwargs,
+                     added_cond_kwargs=added_cond_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 # perform guidance
+                 if do_cfg:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (
+                         noise_pred_text - noise_pred_uncond
+                     )
+                 latents = self.scheduler.step(
+                     noise_pred, t, latents, return_dict=False
+                 )[0]
+             if iters > 0 and t < start_time:
+                 latents = self.AD(latents, t, lr, iters, pbar, weight)
+
+         images = self.latent2image(latents)
+         # Offload all models
+         self.maybe_free_model_hooks()
+         return images
+
+     def optimize(
+         self,
+         latents=None,
+         attn_scale=1.0,
+         lr=0.05,
+         iters=1,
+         weight=0,
+         width=512,
+         height=512,
+         batch_size=1,
+         controller=None,
+         style_image=None,
+         content_image=None,
+         mixed_precision="no",
+         num_inference_steps=50,
+         enable_gradient_checkpoint=False,
+         source_mask=None,
+         target_mask=None,
+     ):
+         height = height // self.vae_scale_factor
+         width = width // self.vae_scale_factor
+
+         self.accelerator = Accelerator(
+             mixed_precision=mixed_precision, gradient_accumulation_steps=1
+         )
+         self.init(enable_gradient_checkpoint)
+
+         style_latent = self.image2latent(style_image)
+         latents = torch.randn((batch_size, 4, height, width), device=self.device)
+         null_embeds = self.encode_prompt("", self.device, 1, False)[0]
+         null_embeds_for_latents = null_embeds.repeat(latents.shape[0], 1, 1)
+         null_embeds_for_style = null_embeds.repeat(style_latent.shape[0], 1, 1)
+
+         if content_image is not None:
+             content_latent = self.image2latent(content_image)
+             latents = torch.cat([content_latent.clone()] * batch_size)
+             null_embeds_for_content = null_embeds.repeat(content_latent.shape[0], 1, 1)
+
+         self.cache = utils.DataCache()
+         self.controller = controller
+         utils.register_attn_control(
+             self.classifier, controller=self.controller, cache=self.cache
+         )
+         print("Total self attention layers of Unet: ", controller.num_self_layers)
+         print("Self attention layers for AD: ", controller.self_layers)
+
+         self.scheduler.set_timesteps(num_inference_steps)
+         timesteps = self.scheduler.timesteps
+         latents = latents.detach().float()
+         optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
+         optimizer = self.accelerator.prepare(optimizer)
+         pbar = tqdm(timesteps, desc="Optimize")
+         for i, t in enumerate(pbar):
+             # t = torch.tensor([1], device=self.device)
+             with torch.no_grad():
+                 qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
+                     style_latent,
+                     t,
+                     null_embeds_for_style,
+                 )
+                 if content_image is not None:
+                     qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
+                         content_latent,
+                         t,
+                         null_embeds_for_content,
+                     )
+             for j in range(iters):
+                 style_loss = 0
+                 content_loss = 0
+                 optimizer.zero_grad()
+                 q_list, k_list, v_list, self_out_list = self.extract_feature(
+                     latents,
+                     t,
+                     null_embeds_for_latents,
+                 )
+                 style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=attn_scale, source_mask=source_mask, target_mask=target_mask)
+                 if content_image is not None:
+                     content_loss = q_loss(q_list, qc_list)
+                     # content_loss = qk_loss(q_list, k_list, qc_list, kc_list)
+                     # content_loss = qkv_loss(q_list, k_list, vc_list, c_out_list)
+                 loss = style_loss + content_loss * weight
+                 self.accelerator.backward(loss)
+                 optimizer.step()
+                 pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
+         images = self.latent2image(latents)
+         # Offload all models
+         self.maybe_free_model_hooks()
+         return images
+
+     def panorama(
+         self,
+         lr=0.05,
+         iters=1,
+         attn_scale=1,
+         adain=False,
+         controller=None,
+         style_image=None,
+         mixed_precision="no",
+         enable_gradient_checkpoint=False,
+         prompt: Union[str, List[str]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 1,
+         stride=8,
+         view_batch_size: int = 16,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         guidance_rescale: float = 0.0,
+         clip_skip: Optional[int] = None,
+         **kwargs,
+     ):
+
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+         self._guidance_scale = guidance_scale
+         self._guidance_rescale = guidance_rescale
+         self._clip_skip = clip_skip
+         self._cross_attention_kwargs = cross_attention_kwargs
+         self._interrupt = False
+
+         self.accelerator = Accelerator(
+             mixed_precision=mixed_precision, gradient_accumulation_steps=1
+         )
+         self.init(enable_gradient_checkpoint)
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_cfg = guidance_scale > 1.0
+
+         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+             image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image,
+                 ip_adapter_image_embeds,
+                 device,
+                 batch_size * num_images_per_prompt,
+                 self.do_classifier_free_guidance,
+             )
+
+         # 3. Encode input prompt
+         text_encoder_lora_scale = (
+             cross_attention_kwargs.get("scale", None)
+             if cross_attention_kwargs is not None
+             else None
+         )
+         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_cfg,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             lora_scale=text_encoder_lora_scale,
+             clip_skip=clip_skip,
+         )
+         # For classifier free guidance, we need to do two forward passes.
+         # Here we concatenate the unconditional and text embeddings into a single batch
+         # to avoid doing two forward passes
+         if do_cfg:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6. Define panorama grid and initialize views for synthesis.
+         # prepare batch grid
+         views = self.get_views_(height, width, window_size=64, stride=stride)
+         views_batch = [
+             views[i : i + view_batch_size]
+             for i in range(0, len(views), view_batch_size)
+         ]
+         print(len(views), len(views_batch), views_batch)
+         self.scheduler.set_timesteps(num_inference_steps)
+         views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(
+             views_batch
+         )
+         count = torch.zeros_like(latents)
+         value = torch.zeros_like(latents)
+
+         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+         # 7.1 Add image embeds for IP-Adapter
+         added_cond_kwargs = (
+             {"image_embeds": image_embeds}
+             if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+             else None
+         )
+
+         # 7.2 Optionally get Guidance Scale Embedding
+         timestep_cond = None
+         if self.unet.config.time_cond_proj_dim is not None:
+             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
+                 batch_size * num_images_per_prompt
+             )
+             timestep_cond = self.get_guidance_scale_embedding(
+                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+             ).to(device=device, dtype=latents.dtype)
+
+         # 8. Denoising loop
+         # Each denoising step also includes refinement of the latents with respect to the
+         # views.
+
+         timesteps = self.scheduler.timesteps
+         self.style_latent = self.image2latent(style_image)
+         self.content_latent = None
+         null_embeds = self.encode_prompt("", device, 1, False)[0]
+         self.null_embeds = null_embeds
+         self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0])
+         self.null_embeds_for_style = torch.cat(
+             [null_embeds] * self.style_latent.shape[0]
+         )
+         self.adain = adain
+         self.attn_scale = attn_scale
+         self.cache = utils.DataCache()
+         self.controller = controller
+         utils.register_attn_control(
+             self.classifier, controller=self.controller, cache=self.cache
+         )
+         print("Total self attention layers of Unet: ", controller.num_self_layers)
+         print("Self attention layers for AD: ", controller.self_layers)
+
+         pbar = tqdm(timesteps, desc="Sample")
+         for i, t in enumerate(pbar):
+             count.zero_()
+             value.zero_()
+             # generate views
+             # Here, we iterate through different spatial crops of the latents and denoise them. These
+             # denoised (latent) crops are then averaged to produce the final latent
+             # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
+             # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
+             # Batch views denoise
+             for j, batch_view in enumerate(views_batch):
+                 vb_size = len(batch_view)
+                 # get the latents corresponding to the current view coordinates
+                 latents_for_view = torch.cat(
+                     [
+                         latents[:, :, h_start:h_end, w_start:w_end]
+                         for h_start, h_end, w_start, w_end in batch_view
+                     ]
+                 )
+                 # rematch block's scheduler status
+                 self.scheduler.__dict__.update(views_scheduler_status[j])
+
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = (
+                     latents_for_view.repeat_interleave(2, dim=0)
+                     if do_cfg
+                     else latents_for_view
+                 )
+
+                 latent_model_input = self.scheduler.scale_model_input(
+                     latent_model_input, t
+                 )
+
+                 # repeat prompt_embeds for batch
+                 prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
+
+                 # predict the noise residual
+                 with torch.no_grad():
+                     noise_pred = self.unet(
+                         latent_model_input,
+                         t,
+                         encoder_hidden_states=prompt_embeds_input,
+                         timestep_cond=timestep_cond,
+                         cross_attention_kwargs=cross_attention_kwargs,
+                         added_cond_kwargs=added_cond_kwargs,
+                     ).sample
+
+                 # perform guidance
+                 if do_cfg:
+                     noise_pred_uncond, noise_pred_text = (
+                         noise_pred[::2],
+                         noise_pred[1::2],
+                     )
+                     noise_pred = noise_pred_uncond + guidance_scale * (
+                         noise_pred_text - noise_pred_uncond
+                     )
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents_denoised_batch = self.scheduler.step(
+                     noise_pred, t, latents_for_view, **extra_step_kwargs
+                 ).prev_sample
+                 if iters > 0:
+                     self.null_embeds_for_latents = torch.cat(
+                         [self.null_embeds] * noise_pred.shape[0]
+                     )
+                     latents_denoised_batch = self.AD(
+                         latents_denoised_batch, t, lr, iters, pbar
+                     )
+                 # save views scheduler status after sample
+                 views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)
+
+                 # extract value from batch
+                 for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
+                     latents_denoised_batch.chunk(vb_size), batch_view
+                 ):
+                     value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+                     count[:, :, h_start:h_end, w_start:w_end] += 1
+
+             # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
+             latents = torch.where(count > 0, value / count, value)
+
+         images = self.latent2image(latents)
+         # Offload all models
+         self.maybe_free_model_hooks()
+         return images
+
+     def AD(self, latents, t, lr, iters, pbar, weight=0):
+         t = max(
+             t
+             - self.scheduler.config.num_train_timesteps
+             // self.scheduler.num_inference_steps,
+             torch.tensor([0], device=self.device),
+         )
+         if self.adain:
+             noise = torch.randn_like(self.style_latent)
+             style_latent = self.scheduler.add_noise(self.style_latent, noise, t)
+             latents = utils.adain(latents, style_latent)
+
+         with torch.no_grad():
+             qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
+                 self.style_latent,
+                 t,
+                 self.null_embeds_for_style,
+                 add_noise=True,
+             )
+             if self.content_latent is not None:
+                 qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
+                     self.content_latent,
+                     t,
+                     self.null_embeds,
+                     add_noise=True,
+                 )
+
+         latents = latents.detach()
+         optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
+         optimizer = self.accelerator.prepare(optimizer)
+
+         for j in range(iters):
+             style_loss = 0
+             content_loss = 0
+             optimizer.zero_grad()
+             q_list, k_list, v_list, self_out_list = self.extract_feature(
+                 latents,
+                 t,
+                 self.null_embeds_for_latents,
+                 add_noise=False,
+             )
+             style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=self.attn_scale)
+             if self.content_latent is not None:
+                 content_loss = q_loss(q_list, qc_list)
+                 # content_loss = qk_loss(q_list, k_list, qc_list, kc_list)
+                 # content_loss = qkv_loss(q_list, k_list, vc_list, c_out_list)
+             loss = style_loss + content_loss * weight
+             self.accelerator.backward(loss)
+             optimizer.step()
+
+             pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
+         latents = latents.detach()
+         return latents
+
+     def extract_feature(
+         self,
+         latent,
+         t,
+         embeds,
+         add_noise=False,
+     ):
+         self.cache.clear()
+         self.controller.step()
+         if add_noise:
+             noise = torch.randn_like(latent)
+             latent_ = self.scheduler.add_noise(latent, noise, t)
+         else:
+             latent_ = latent
+         _ = self.classifier(latent_, t, embeds)[0]
+         return self.cache.get()
+
+     def get_views_(
+         self,
+         panorama_height: int,
+         panorama_width: int,
+         window_size: int = 64,
+         stride: int = 8,
+     ) -> List[Tuple[int, int, int, int]]:
+         panorama_height //= 8
+         panorama_width //= 8
+
+         num_blocks_height = (
+             math.ceil((panorama_height - window_size) / stride) + 1
+             if panorama_height > window_size
+             else 1
+         )
+         num_blocks_width = (
+             math.ceil((panorama_width - window_size) / stride) + 1
+             if panorama_width > window_size
+             else 1
+         )
+
+         views = []
+         for i in range(int(num_blocks_height)):
+             for j in range(int(num_blocks_width)):
+                 h_start = int(min(i * stride, panorama_height - window_size))
+                 w_start = int(min(j * stride, panorama_width - window_size))
+
+                 h_end = h_start + window_size
+                 w_end = w_start + window_size
+
+                 views.append((h_start, h_end, w_start, w_end))
+
+         return views
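A sketch of how ADPipeline is driven for style transfer, mirroring webui/runner.py; the checkpoint path and image files are placeholders, not part of this commit:

import torch
from diffusers import DDIMScheduler
from pipeline_sd import ADPipeline
from utils import Controller, load_image

ckpt = "./checkpoints/stable-diffusion-v1-5"  # placeholder path
scheduler = DDIMScheduler.from_pretrained(ckpt, subfolder="scheduler")
pipe = ADPipeline.from_pretrained(ckpt, scheduler=scheduler, safety_checker=None)
pipe.classifier = pipe.unet  # the frozen UNet doubles as the attention feature extractor ("classifier")

style = load_image("style.jpg", size=(512, 512))
content = load_image("content.jpg")

images = pipe.optimize(
    lr=0.05,
    iters=1,
    weight=0.25,  # illustrative content-preservation weight (q_loss term)
    height=content.shape[-2],
    width=content.shape[-1],
    controller=Controller(self_layers=(10, 16)),
    style_image=style,
    content_image=content,
    num_inference_steps=50,
)  # returns a tensor of images in [0, 1]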
pipeline_sdxl.py ADDED
@@ -0,0 +1,403 @@
+ import math
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ import utils
+ from accelerate import Accelerator
+ from accelerate.utils import (
+     DistributedDataParallelKwargs,
+     ProjectConfiguration,
+     set_seed,
+ )
+ from diffusers import StableDiffusionXLPipeline
+ from diffusers.image_processor import PipelineImageInput
+ from diffusers.utils.torch_utils import is_compiled_module
+ from losses import *
+
+ # from peft import LoraConfig, set_peft_model_state_dict
+ from tqdm import tqdm
+
+
+ class ADPipeline(StableDiffusionXLPipeline):
+     def freeze(self):
+         self.unet.requires_grad_(False)
+         self.text_encoder.requires_grad_(False)
+         self.text_encoder_2.requires_grad_(False)
+         self.vae.requires_grad_(False)
+         self.classifier.requires_grad_(False)
+
+     @torch.no_grad()
+     def image2latent(self, image):
+         dtype = next(self.vae.parameters()).dtype
+         device = self._execution_device
+         image = image.to(device=device, dtype=dtype) * 2.0 - 1.0
+         latent = self.vae.encode(image)["latent_dist"].mean
+         latent = latent * self.vae.config.scaling_factor
+         return latent
+
+     @torch.no_grad()
+     def latent2image(self, latent):
+         dtype = next(self.vae.parameters()).dtype
+         device = self._execution_device
+         latent = latent.to(device=device, dtype=dtype)
+         latent = latent / self.vae.config.scaling_factor
+         image = self.vae.decode(latent)[0]
+         return (image * 0.5 + 0.5).clamp(0, 1)
+
+     def init(self, enable_gradient_checkpoint):
+         self.freeze()
+         weight_dtype = torch.float32
+         if self.accelerator.mixed_precision == "fp16":
+             weight_dtype = torch.float16
+         elif self.accelerator.mixed_precision == "bf16":
+             weight_dtype = torch.bfloat16
+
+         # Move unet, vae and text_encoder to device and cast to weight_dtype
+         self.unet.to(self.accelerator.device, dtype=weight_dtype)
+         self.vae.to(self.accelerator.device, dtype=weight_dtype)
+         self.text_encoder.to(self.accelerator.device, dtype=weight_dtype)
+         self.text_encoder_2.to(self.accelerator.device, dtype=weight_dtype)
+         self.classifier.to(self.accelerator.device, dtype=weight_dtype)
+         self.classifier = self.accelerator.prepare(self.classifier)
+         if enable_gradient_checkpoint:
+             self.classifier.enable_gradient_checkpointing()
+         # self.classifier.train()
+
+     def sample(
+         self,
+         lr=0.05,
+         iters=1,
+         adain=True,
+         controller=None,
+         style_image=None,
+         mixed_precision="no",
+         init_from_style=False,
+         start_time=999,
+         prompt: Union[str, List[str]] = None,
+         prompt_2: Optional[Union[str, List[str]]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         denoising_end: Optional[float] = None,
+         guidance_scale: float = 5.0,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         negative_prompt_2: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         pooled_prompt_embeds: Optional[torch.Tensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+         ip_adapter_image: Optional[PipelineImageInput] = None,
+         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         guidance_rescale: float = 0.0,
+         original_size: Optional[Tuple[int, int]] = None,
+         crops_coords_top_left: Tuple[int, int] = (0, 0),
+         target_size: Optional[Tuple[int, int]] = None,
+         negative_original_size: Optional[Tuple[int, int]] = None,
+         negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+         negative_target_size: Optional[Tuple[int, int]] = None,
+         clip_skip: Optional[int] = None,
+         enable_gradient_checkpoint=False,
+         **kwargs,
+     ):
+         # 0. Default height and width to unet
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+
+         original_size = original_size or (height, width)
+         target_size = target_size or (height, width)
+         self._guidance_scale = guidance_scale
+         self._guidance_rescale = guidance_rescale
+         self._clip_skip = clip_skip
+         self._cross_attention_kwargs = cross_attention_kwargs
+         self._denoising_end = denoising_end
+         self._interrupt = False
+
+         self.accelerator = Accelerator(
+             mixed_precision=mixed_precision, gradient_accumulation_steps=1
+         )
+         self.init(enable_gradient_checkpoint)
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+
+         # 3. Encode input prompt
+         lora_scale = (
+             self.cross_attention_kwargs.get("scale", None)
+             if self.cross_attention_kwargs is not None
+             else None
+         )
+
+         (
+             prompt_embeds,
+             negative_prompt_embeds,
+             pooled_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             prompt_2=prompt_2,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             lora_scale=lora_scale,
+             clip_skip=self.clip_skip,
+         )
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 7. Prepare added time ids & embeddings
+         add_text_embeds = pooled_prompt_embeds
+         if self.text_encoder_2 is None:
+             text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+         else:
+             text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+         add_time_ids = self._get_add_time_ids(
+             original_size,
+             crops_coords_top_left,
+             target_size,
+             dtype=prompt_embeds.dtype,
+             text_encoder_projection_dim=text_encoder_projection_dim,
+         )
+         null_add_time_ids = add_time_ids.to(device)
+         if negative_original_size is not None and negative_target_size is not None:
+             negative_add_time_ids = self._get_add_time_ids(
+                 negative_original_size,
+                 negative_crops_coords_top_left,
+                 negative_target_size,
+                 dtype=prompt_embeds.dtype,
+                 text_encoder_projection_dim=text_encoder_projection_dim,
+             )
+         else:
+             negative_add_time_ids = add_time_ids
+
+         if self.do_classifier_free_guidance:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+             add_text_embeds = torch.cat(
+                 [negative_pooled_prompt_embeds, add_text_embeds], dim=0
+             )
+             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+         prompt_embeds = prompt_embeds.to(device)
+         add_text_embeds = add_text_embeds.to(device)
+         add_time_ids = add_time_ids.to(device).repeat(
+             batch_size * num_images_per_prompt, 1
+         )
+
+         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+             image_embeds = self.prepare_ip_adapter_image_embeds(
+                 ip_adapter_image,
+                 ip_adapter_image_embeds,
+                 device,
+                 batch_size * num_images_per_prompt,
+                 self.do_classifier_free_guidance,
+             )
+         # 8.1 Apply denoising_end
+         if (
+             self.denoising_end is not None
+             and isinstance(self.denoising_end, float)
+             and self.denoising_end > 0
+             and self.denoising_end < 1
+         ):
+             discrete_timestep_cutoff = int(
+                 round(
+                     self.scheduler.config.num_train_timesteps
+                     - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                 )
+             )
+             num_inference_steps = len(
+                 list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
+             )
+             timesteps = timesteps[:num_inference_steps]
+
+         # 9. Optionally get Guidance Scale Embedding
+         timestep_cond = None
+         if self.unet.config.time_cond_proj_dim is not None:
+             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
+                 batch_size * num_images_per_prompt
+             )
+             timestep_cond = self.get_guidance_scale_embedding(
+                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+             ).to(device=device, dtype=latents.dtype)
+         self.timestep_cond = timestep_cond
+         (null_embeds, _, null_pooled_embeds, _) = self.encode_prompt("", device=device)
+
+         added_cond_kwargs = {
+             "text_embeds": add_text_embeds,
+             "time_ids": add_time_ids
+         }
+         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+             added_cond_kwargs["image_embeds"] = image_embeds
+
+         self.scheduler.set_timesteps(num_inference_steps)
+
+         timesteps = self.scheduler.timesteps
+         style_latent = self.image2latent(style_image)
+         if init_from_style:
+             latents = torch.cat([style_latent] * latents.shape[0])
+             noise = torch.randn_like(latents)
+             latents = self.scheduler.add_noise(
+                 latents,
+                 noise,
+                 torch.tensor([999]),
+             )
+
+         self.style_latent = style_latent
+         self.null_embeds_for_latents = torch.cat([null_embeds] * (latents.shape[0]))
+         self.null_embeds_for_style = torch.cat([null_embeds] * style_latent.shape[0])
+         self.null_added_cond_kwargs_for_latents = {
+             "text_embeds": torch.cat([null_pooled_embeds] * (latents.shape[0])),
+             "time_ids": torch.cat([null_add_time_ids] * (latents.shape[0])),
+         }
+         self.null_added_cond_kwargs_for_style = {
+             "text_embeds": torch.cat([null_pooled_embeds] * style_latent.shape[0]),
+             "time_ids": torch.cat([null_add_time_ids] * style_latent.shape[0]),
+         }
+         self.adain = adain
+         self.cache = utils.DataCache()
+         self.controller = controller
+         utils.register_attn_control(
+             self.classifier, controller=controller, cache=self.cache
+         )
+         print("Total self attention layers of Unet: ", controller.num_self_layers)
+         print("Self attention layers for AD: ", controller.self_layers)
+
+         pbar = tqdm(timesteps, desc="Sample")
+         for i, t in enumerate(pbar):
+             with torch.no_grad():
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = (
+                     torch.cat([latents] * 2)
+                     if self.do_classifier_free_guidance
+                     else latents
+                 )
+
+                 # predict the noise residual
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     timestep_cond=timestep_cond,
+                     cross_attention_kwargs=self.cross_attention_kwargs,
+                     added_cond_kwargs=added_cond_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 # perform guidance
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (
+                         noise_pred_text - noise_pred_uncond
+                     )
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+             if iters > 0 and t < start_time:
+                 latents = self.AD(latents, t, lr, iters, pbar)
+
+         # Offload all models
+         # self.enable_model_cpu_offload()
+         images = self.latent2image(latents)
+         self.maybe_free_model_hooks()
+         return images
+
+     def AD(self, latents, t, lr, iters, pbar):
+         t = max(
+             t
+             - self.scheduler.config.num_train_timesteps
+             // self.scheduler.num_inference_steps,
+             torch.tensor([0], device=self.device),
+         )
+
+         if self.adain:
+             noise = torch.randn_like(self.style_latent)
+             style_latent = self.scheduler.add_noise(self.style_latent, noise, t)
+             latents = utils.adain(latents, style_latent)
+
+         with torch.no_grad():
+             qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
+                 self.style_latent,
+                 t,
+                 self.null_embeds_for_style,
+                 self.timestep_cond,
+                 self.null_added_cond_kwargs_for_style,
+                 add_noise=True,
+             )
+         # latents = latents.to(dtype=torch.float32)
+         latents = latents.detach()
+         optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
+         optimizer, latents = self.accelerator.prepare(optimizer, latents)
+
+         for j in range(iters):
+             optimizer.zero_grad()
+             q_list, k_list, v_list, self_out_list = self.extract_feature(
+                 latents,
+                 t,
+                 self.null_embeds_for_latents,
+                 self.timestep_cond,
+                 self.null_added_cond_kwargs_for_latents,
+                 add_noise=False,
+             )
+
+             loss = ad_loss(q_list, ks_list, vs_list, self_out_list)
+             self.accelerator.backward(loss)
+             optimizer.step()
+
+             pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
+         latents = latents.detach()
+         return latents
+
+     def extract_feature(
+         self,
+         latent,
+         t,
+         encoder_hidden_states,
+         timestep_cond,
+         added_cond_kwargs,
+         add_noise=False,
+     ):
+         self.cache.clear()
+         self.controller.step()
+         if add_noise:
+             noise = torch.randn_like(latent)
+             latent_ = self.scheduler.add_noise(latent, noise, t)
+         else:
+             latent_ = latent
+         self.classifier(
+             latent_,
+             t,
+             encoder_hidden_states=encoder_hidden_states,
+             timestep_cond=timestep_cond,
+             added_cond_kwargs=added_cond_kwargs,
+             return_dict=False,
+         )[0]
+         return self.cache.get()
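The SDXL variant is used the same way for style-specific text-to-image generation (cf. run_style_t2i_generation in webui/runner.py); the checkpoint path and the hyperparameters below are placeholders:

from diffusers import DDIMScheduler
from pipeline_sdxl import ADPipeline as ADXLPipeline
from utils import Controller, load_image

ckpt = "./checkpoints/stable-diffusion-xl-base-1.0"  # placeholder path
scheduler = DDIMScheduler.from_pretrained(ckpt, subfolder="scheduler")
pipe = ADXLPipeline.from_pretrained(ckpt, scheduler=scheduler)
pipe.classifier = pipe.unet

images = pipe.sample(
    controller=Controller(self_layers=(64, 70)),  # SDXL exposes more self-attention layers than SD 1.5
    iters=2,   # illustrative
    lr=0.02,   # illustrative
    adain=True,
    height=1024,
    width=1024,
    style_image=load_image("style.jpg", size=(1024, 1024)),
    prompt="a cat",
    guidance_scale=5.0,
    num_inference_steps=50,
)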
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ diffusers
+ torch>=2.0.0
+ torchvision
+ transformers
+ accelerate
+ safetensors
+ spaces
+ huggingface-hub
+ gradio
+ matplotlib
train_vae.py ADDED
@@ -0,0 +1,87 @@
+ import argparse
+ import os
+
+ import torch
+ from diffusers import AutoencoderKL
+ from torch import nn
+ from torch.optim import Adam
+ from utils import load_image, save_image
+
+
+ def main(args):
+     os.makedirs(args.out_dir, exist_ok=True)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     vae = AutoencoderKL.from_pretrained(args.vae_model_path).to(
+         device, dtype=torch.float32
+     )
+     vae.requires_grad_(False)
+
+     image = load_image(args.image_path, size=(512, 512)).to(device, dtype=torch.float32)
+     image = image * 2 - 1
+     save_image(image / 2 + 0.5, f"{args.out_dir}/ori_image.png")
+
+     latents = vae.encode(image)["latent_dist"].mean
+     save_image(latents, f"{args.out_dir}/latents.png")
+
+     rec_image = vae.decode(latents, return_dict=False)[0]
+     save_image(rec_image / 2 + 0.5, f"{args.out_dir}/rec_image.png")
+
+     for param in vae.decoder.parameters():
+         param.requires_grad = True
+
+     loss_fn = nn.L1Loss()
+     optimizer = Adam(vae.decoder.parameters(), lr=args.learning_rate)
+
+     # Training loop
+     for epoch in range(args.num_epochs):
+         reconstructed = vae.decode(latents, return_dict=False)[0]
+         loss = loss_fn(reconstructed, image)
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         print(f"Epoch {epoch+1}/{args.num_epochs}, Loss: {loss.item()}")
+
+     rec_image = vae.decode(latents, return_dict=False)[0]
+     save_image(rec_image / 2 + 0.5, f"{args.out_dir}/trained_rec_image.png")
+     vae.save_pretrained(
+         f"{args.out_dir}/trained_vae_{os.path.basename(args.image_path)}"
+     )
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Train a VAE with given image and settings."
+     )
+
+     # Add arguments
+     parser.add_argument(
+         "--out_dir",
+         type=str,
+         default="./trained_vae/",
+         help="Output directory to save results",
+     )
+     parser.add_argument(
+         "--vae_model_path",
+         type=str,
+         required=True,
+         help="Path to the pretrained VAE model",
+     )
+     parser.add_argument(
+         "--image_path", type=str, required=True, help="Path to the input image"
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=1e-4,
+         help="Learning rate for the optimizer",
+     )
+     parser.add_argument(
+         "--num_epochs", type=int, default=75, help="Number of training epochs"
+     )
+
+     args = parser.parse_args()
+     main(args)
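Usage sketch (all paths are placeholders): `python train_vae.py --vae_model_path <pretrained_vae_dir> --image_path style.jpg` fine-tunes only the VAE decoder on the reconstruction of a single image and writes the original, the pre- and post-training reconstructions, and the trained weights to --out_dir.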
utils.py ADDED
@@ -0,0 +1,195 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from torchvision.transforms import ToTensor
+ from torchvision.utils import save_image
+ import matplotlib.pyplot as plt
+ import math
+
+
+ def register_attn_control(unet, controller, cache=None):
+     def attn_forward(self):
+         def forward(
+             hidden_states,
+             encoder_hidden_states=None,
+             attention_mask=None,
+             temb=None,
+             *args,
+             **kwargs,
+         ):
+             residual = hidden_states
+             if self.spatial_norm is not None:
+                 hidden_states = self.spatial_norm(hidden_states, temb)
+
+             input_ndim = hidden_states.ndim
+
+             if input_ndim == 4:
+                 batch_size, channel, height, width = hidden_states.shape
+                 hidden_states = hidden_states.view(
+                     batch_size, channel, height * width
+                 ).transpose(1, 2)
+
+             batch_size, sequence_length, _ = (
+                 hidden_states.shape
+                 if encoder_hidden_states is None
+                 else encoder_hidden_states.shape
+             )
+
+             if attention_mask is not None:
+                 attention_mask = self.prepare_attention_mask(
+                     attention_mask, sequence_length, batch_size
+                 )
+                 # scaled_dot_product_attention expects attention_mask shape to be
+                 # (batch, heads, source_length, target_length)
+                 attention_mask = attention_mask.view(
+                     batch_size, self.heads, -1, attention_mask.shape[-1]
+                 )
+
+             if self.group_norm is not None:
+                 hidden_states = self.group_norm(
+                     hidden_states.transpose(1, 2)
+                 ).transpose(1, 2)
+
+             q = self.to_q(hidden_states)
+             is_self = encoder_hidden_states is None
+
+             if encoder_hidden_states is None:
+                 encoder_hidden_states = hidden_states
+             elif self.norm_cross:
+                 encoder_hidden_states = self.norm_encoder_hidden_states(
+                     encoder_hidden_states
+                 )
+
+             k = self.to_k(encoder_hidden_states)
+             v = self.to_v(encoder_hidden_states)
+
+             inner_dim = k.shape[-1]
+             head_dim = inner_dim // self.heads
+
+             q = q.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+             k = k.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+             v = v.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
+             # the output of sdp = (batch, num_heads, seq_len, head_dim)
+             # TODO: add support for attn.scale when we move to Torch 2.1
+             hidden_states = F.scaled_dot_product_attention(
+                 q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+             )
+             if is_self and controller.cur_self_layer in controller.self_layers:
+                 cache.add(q, k, v, hidden_states)
+
+             hidden_states = hidden_states.transpose(1, 2).reshape(
+                 batch_size, -1, self.heads * head_dim
+             )
+             hidden_states = hidden_states.to(q.dtype)
+
+             # linear proj
+             hidden_states = self.to_out[0](hidden_states)
+             # dropout
+             hidden_states = self.to_out[1](hidden_states)
+
+             if input_ndim == 4:
+                 hidden_states = hidden_states.transpose(-1, -2).reshape(
+                     batch_size, channel, height, width
+                 )
+             if self.residual_connection:
+                 hidden_states = hidden_states + residual
+
+             hidden_states = hidden_states / self.rescale_output_factor
+
+             if is_self:
+                 controller.cur_self_layer += 1
+
+             return hidden_states
+
+         return forward
+
+     def modify_forward(net, count):
+         for name, subnet in net.named_children():
+             if net.__class__.__name__ == "Attention":  # spatial Transformer layer
+                 net.forward = attn_forward(net)
+                 return count + 1
+             elif hasattr(net, "children"):
+                 count = modify_forward(subnet, count)
+         return count
+
+     cross_att_count = 0
+     for net_name, net in unet.named_children():
+         cross_att_count += modify_forward(net, 0)
+     controller.num_self_layers = cross_att_count // 2
+
+
+ def load_image(image_path, size=None, mode="RGB"):
+     img = Image.open(image_path).convert(mode)
+     if size is None:
+         width, height = img.size
+         new_width = (width // 64) * 64
+         new_height = (height // 64) * 64
+         size = (new_width, new_height)
+     img = img.resize(size, Image.BICUBIC)
+     return ToTensor()(img).unsqueeze(0)
+
+
+ def adain(source, target, eps=1e-6):
+     source_mean, source_std = torch.mean(source, dim=(2, 3), keepdim=True), torch.std(
+         source, dim=(2, 3), keepdim=True
+     )
+     target_mean, target_std = torch.mean(
+         target, dim=(0, 2, 3), keepdim=True
+     ), torch.std(target, dim=(0, 2, 3), keepdim=True)
+     normalized_source = (source - source_mean) / (source_std + eps)
+     transferred_source = normalized_source * target_std + target_mean
+
+     return transferred_source
+
+
+ class Controller:
+     def step(self):
+         self.cur_self_layer = 0
+
+     def __init__(self, self_layers=(0, 16)):
+         self.num_self_layers = -1
+         self.cur_self_layer = 0
+         self.self_layers = list(range(*self_layers))
+
+
+ class DataCache:
+     def __init__(self):
+         self.q = []
+         self.k = []
+         self.v = []
+         self.out = []
+
+     def clear(self):
+         self.q.clear()
+         self.k.clear()
+         self.v.clear()
+         self.out.clear()
+
+     def add(self, q, k, v, out):
+         self.q.append(q)
+         self.k.append(k)
+         self.v.append(v)
+         self.out.append(out)
+
+     def get(self):
+         return self.q.copy(), self.k.copy(), self.v.copy(), self.out.copy()
+
+
+ def show_image(path, title, display_height=3, title_fontsize=12):
+     img = Image.open(path)
+     img_width, img_height = img.size
+
+     aspect_ratio = img_width / img_height
+     display_width = display_height * aspect_ratio
+
+     plt.figure(figsize=(display_width, display_height))
+     plt.imshow(img)
+     plt.title(title,
+               fontsize=title_fontsize,
+               fontweight='bold',
+               pad=20)
+     plt.axis('off')
+     plt.tight_layout()
+     plt.show()
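A sketch of how these helpers fit together (this mirrors ADPipeline.extract_feature): register_attn_control patches every Attention module of a UNet so that, during a forward pass, the q/k/v tensors and attention outputs of the selected self-attention layers are pushed into a DataCache. The model id below is a placeholder:

import torch
from diffusers import UNet2DConditionModel
from utils import Controller, DataCache, register_attn_control

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")  # placeholder id

cache = DataCache()
controller = Controller(self_layers=(10, 16))  # hook the last six self-attention layers
register_attn_control(unet, controller=controller, cache=cache)

controller.step()  # reset the per-forward self-attention layer counter
cache.clear()
latent = torch.randn(1, 4, 64, 64)  # a 512x512 image in SD latent space
embeds = torch.randn(1, 77, 768)    # stand-in for a (null) text embedding
with torch.no_grad():
    _ = unet(latent, torch.tensor(999), embeds)[0]
q_list, k_list, v_list, out_list = cache.get()  # one entry per hooked layer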
webui/__init__.py ADDED
@@ -0,0 +1,5 @@
+
+ from .tab_style_t2i import create_interface_style_t2i
+ from .tab_style_transfer import create_interface_style_transfer
+ from .tab_texture_synthesis import create_interface_texture_synthesis
+ from .runner import Runner
webui/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (420 Bytes).
webui/__pycache__/runner.cpython-310.pyc ADDED
Binary file (4.38 kB).
webui/__pycache__/tab_style_t2i.cpython-310.pyc ADDED
Binary file (2.5 kB).
webui/__pycache__/tab_style_transfer.cpython-310.pyc ADDED
Binary file (2.21 kB).
webui/__pycache__/tab_texture_synthesis.cpython-310.pyc ADDED
Binary file (2.25 kB).
webui/images/40.jpg ADDED
webui/images/42.jpg ADDED
webui/images/image_02_01.jpg ADDED
webui/images/lecun.png ADDED
webui/runner.py ADDED
@@ -0,0 +1,157 @@
+ import torch
+ from PIL import Image
+ from diffusers import DDIMScheduler
+ from accelerate.utils import set_seed
+ from torchvision.transforms.functional import to_pil_image, to_tensor
+
+ from pipeline_sd import ADPipeline
+ from pipeline_sdxl import ADPipeline as ADXLPipeline
+ from utils import Controller
+
+ import os
+
+
+ class Runner:
+     def __init__(self):
+         self.sd15 = None
+         self.sdxl = None
+         self.loss_fn = torch.nn.L1Loss(reduction="mean")
+
+     def load_pipeline(self, model_path_or_name):
+
+         if 'xl' in model_path_or_name and self.sdxl is None:
+             scheduler = DDIMScheduler.from_pretrained(os.path.join('./checkpoints', model_path_or_name), subfolder="scheduler")
+             self.sdxl = ADXLPipeline.from_pretrained(os.path.join('./checkpoints', model_path_or_name), scheduler=scheduler, safety_checker=None)
+             self.sdxl.classifier = self.sdxl.unet
+         elif self.sd15 is None:
+             scheduler = DDIMScheduler.from_pretrained(os.path.join('./checkpoints', model_path_or_name), subfolder="scheduler")
+             self.sd15 = ADPipeline.from_pretrained(os.path.join('./checkpoints', model_path_or_name), scheduler=scheduler, safety_checker=None)
+             self.sd15.classifier = self.sd15.unet
+
+     def preprocecss(self, image: Image.Image, height=None, width=None):
+         if width is None or height is None:
+             width, height = image.size
+         new_width = (width // 64) * 64
+         new_height = (height // 64) * 64
+         size = (new_width, new_height)
+         image = image.resize(size, Image.BICUBIC)
+         return to_tensor(image).unsqueeze(0)
+
+     def run_style_transfer(self, content_image, style_image, seed, num_steps, lr, content_weight, mixed_precision, model, **kwargs):
+         self.load_pipeline(model)
+
+         content_image = self.preprocecss(content_image)
+         style_image = self.preprocecss(style_image, height=512, width=512)
+
+         height, width = content_image.shape[-2:]
+         set_seed(seed)
+         controller = Controller(self_layers=(10, 16))
+         result = self.sd15.optimize(
+             lr=lr,
+             batch_size=1,
+             iters=1,
+             width=width,
+             height=height,
+             weight=content_weight,
+             controller=controller,
+             style_image=style_image,
+             content_image=content_image,
+             mixed_precision=mixed_precision,
+             num_inference_steps=num_steps,
+             enable_gradient_checkpoint=False,
+         )
+         output_image = to_pil_image(result[0])
+         del result
+         torch.cuda.empty_cache()
+         return [output_image]
+
+     def run_style_t2i_generation(self, style_image, prompt, negative_prompt, guidance_scale, height, width, seed, num_steps, iterations, lr, num_images_per_prompt, mixed_precision, is_adain, model):
+         self.load_pipeline(model)
+
+         use_xl = 'xl' in model
+         height, width = (1024, 1024) if 'xl' in model else (512, 512)
+         style_image = self.preprocecss(style_image, height=height, width=width)
+
+         set_seed(seed)
+         self_layers = (64, 70) if use_xl else (10, 16)
+
+         controller = Controller(self_layers=self_layers)
+
+         pipeline = self.sdxl if use_xl else self.sd15
+         images = pipeline.sample(
+             controller=controller,
+             iters=iterations,
+             lr=lr,
+             adain=is_adain,
+             height=height,
+             width=width,
+             mixed_precision=mixed_precision,
+             style_image=style_image,
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_steps,
+             num_images_per_prompt=num_images_per_prompt,
+             enable_gradient_checkpoint=False
+         )
+         output_images = [to_pil_image(image) for image in images]
+
+         del images
+         torch.cuda.empty_cache()
+         return output_images
+
+     def run_texture_synthesis(self, texture_image, height, width, seed, num_steps, iterations, lr, mixed_precision, num_images_per_prompt, synthesis_way, model):
+         self.load_pipeline(model)
+
+         texture_image = self.preprocecss(texture_image, height=512, width=512)
+
+         set_seed(seed)
+         controller = Controller(self_layers=(10, 16))
+
+         if synthesis_way == 'Sampling':
+             results = self.sd15.sample(
+                 lr=lr,
+                 adain=False,
+                 iters=iterations,
+                 width=width,
+                 height=height,
+                 weight=0.,
+                 controller=controller,
+                 style_image=texture_image,
+                 content_image=None,
+                 prompt="",
+                 negative_prompt="",
+                 mixed_precision=mixed_precision,
+                 num_inference_steps=num_steps,
+                 guidance_scale=1.,
+                 num_images_per_prompt=num_images_per_prompt,
+                 enable_gradient_checkpoint=False,
+             )
+         elif synthesis_way == 'MultiDiffusion':
+             results = self.sd15.panorama(
+                 lr=lr,
+                 iters=iterations,
+                 width=width,
+                 height=height,
+                 weight=0.,
+                 controller=controller,
+                 style_image=texture_image,
+                 content_image=None,
+                 prompt="",
+                 negative_prompt="",
+                 stride=8,
+                 view_batch_size=8,
+                 mixed_precision=mixed_precision,
+                 num_inference_steps=num_steps,
+                 guidance_scale=1.,
+                 num_images_per_prompt=num_images_per_prompt,
+                 enable_gradient_checkpoint=False,
+             )
+         else:
+             raise ValueError
+
+         output_images = [to_pil_image(image) for image in results]
+         del results
+         torch.cuda.empty_cache()
+         return output_images
+
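For reference (not part of the commit): a minimal sketch of driving the style-transfer path without the Gradio UI, assuming the stable-diffusion-v1-5 checkpoint sits under ./checkpoints as load_pipeline expects and that the example images exist; the argument values simply mirror the defaults exposed in the style-transfer tab below.

from PIL import Image
from webui.runner import Runner

runner = Runner()
# Defaults from the style-transfer tab: 200 steps, lr 0.05, content weight 0.25, bf16.
stylized = runner.run_style_transfer(
    content_image=Image.open('examples/c1.jpg').convert('RGB'),
    style_image=Image.open('examples/s1.jpg').convert('RGB'),
    seed=2025,
    num_steps=200,
    lr=0.05,
    content_weight=0.25,
    mixed_precision='bf16',
    model='stable-diffusion-v1-5',
)
stylized[0].save('stylized.png')  # run_style_transfer returns a one-image list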
webui/tab_style_t2i.py ADDED
@@ -0,0 +1,51 @@
+ import os
+ from PIL import Image
+ import gradio as gr
+
+
+ def create_interface_style_t2i(runner):
+     with gr.Blocks():
+         with gr.Row():
+             gr.Markdown('1. Upload the style image and enter your prompt.\n'
+                         '2. Choose the generative model.\n'
+                         '3. (Optional) Customize the configurations below as needed.\n'
+                         '4. Click `Run` to start generation.')
+
+         with gr.Row():
+             with gr.Column():
+                 style_image = gr.Image(label='Input Style Image', type='pil', interactive=True,
+                                        value=Image.open('examples/s1.jpg').convert('RGB') if os.path.exists('examples/s1.jpg') else None)
+                 prompt = gr.Textbox(label='Prompt', value='A rocket')
+                 negative_prompt = gr.Textbox(label='Negative Prompt', value='')
+
+                 base_model_list = ['stable-diffusion-v1-5', 'stable-diffusion-xl-base-1.0']
+                 model = gr.Radio(choices=base_model_list, label='Select a Base Model', value='stable-diffusion-xl-base-1.0')
+
+                 run_button = gr.Button(value='Run')
+
+                 gr.Examples(
+                     [[Image.open('./webui/images/image_02_01.jpg').convert('RGB'), 'A rocket', 'stable-diffusion-xl-base-1.0']],
+                     [style_image, prompt, model]
+                 )
+
+             with gr.Column():
+                 with gr.Accordion('Options', open=True):
+                     guidance_scale = gr.Slider(label='Guidance Scale', minimum=1., maximum=30., value=7.5, step=0.1)
+                     height = gr.Number(label='Height', value=1024, precision=0, minimum=2, maximum=4096)
+                     width = gr.Number(label='Width', value=1024, precision=0, minimum=2, maximum=4096)
+                     seed = gr.Number(label='Seed', value=2025, precision=0, minimum=0, maximum=2**31)
+                     num_steps = gr.Slider(label='Number of Steps', minimum=1, maximum=1000, value=50, step=1)
+                     iterations = gr.Slider(label='Iterations', minimum=0, maximum=10, value=2, step=1)
+                     lr = gr.Slider(label='Learning Rate', minimum=0.01, maximum=0.5, value=0.015, step=0.001)
+                     num_images_per_prompt = gr.Slider(label='Num Images Per Prompt', minimum=1, maximum=10, value=1, step=1)
+                     mixed_precision = gr.Radio(choices=['bf16', 'no'], value='bf16', label='Mixed Precision')
+                     is_adain = gr.Checkbox(label='AdaIN', value=True)
+
+             with gr.Column():
+                 gr.Markdown('#### Output Image:\n')
+                 result_gallery = gr.Gallery(label='Output', elem_id='gallery', columns=2, height='auto', preview=True)
+
+         ips = [style_image, prompt, negative_prompt, guidance_scale, height, width, seed, num_steps, iterations, lr, num_images_per_prompt, mixed_precision, is_adain, model]
+
+         run_button.click(fn=runner.run_style_t2i_generation, inputs=ips, outputs=[result_gallery])
+
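Likewise, a hedged sketch of calling the style-specific text-to-image path directly with the tab's default values; note that run_style_t2i_generation overrides height and width to 1024x1024 whenever the model name contains 'xl', so the values passed here only matter for the SD 1.5 branch.

from PIL import Image
from webui.runner import Runner

runner = Runner()
# Uses the SDXL branch of load_pipeline; style image and prompt come from the tab's example row.
images = runner.run_style_t2i_generation(
    style_image=Image.open('./webui/images/image_02_01.jpg').convert('RGB'),
    prompt='A rocket',
    negative_prompt='',
    guidance_scale=7.5,
    height=1024,
    width=1024,
    seed=2025,
    num_steps=50,
    iterations=2,
    lr=0.015,
    num_images_per_prompt=1,
    mixed_precision='bf16',
    is_adain=True,
    model='stable-diffusion-xl-base-1.0',
)
images[0].save('style_t2i.png')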
webui/tab_style_transfer.py ADDED
@@ -0,0 +1,45 @@
+ import os
+ from PIL import Image
+ import gradio as gr
+
+
+ def create_interface_style_transfer(runner):
+     with gr.Blocks():
+         with gr.Row():
+             gr.Markdown('1. Upload the content and style images as inputs.\n'
+                         '2. (Optional) Customize the configurations below as needed.\n'
+                         '3. Click `Run` to start transfer.')
+
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     content_image = gr.Image(label='Input Content Image', type='pil', interactive=True,
+                                              value=Image.open('examples/c1.jpg').convert('RGB') if os.path.exists('examples/c1.jpg') else None)
+                     style_image = gr.Image(label='Input Style Image', type='pil', interactive=True,
+                                            value=Image.open('examples/s1.jpg').convert('RGB') if os.path.exists('examples/s1.jpg') else None)
+
+                 run_button = gr.Button(value='Run')
+
+                 with gr.Accordion('Options', open=True):
+                     seed = gr.Number(label='Seed', value=2025, precision=0, minimum=0, maximum=2**31)
+                     num_steps = gr.Slider(label='Number of Steps', minimum=1, maximum=1000, value=200, step=1)
+                     lr = gr.Slider(label='Learning Rate', minimum=0.01, maximum=0.5, value=0.05, step=0.01)
+                     content_weight = gr.Slider(label='Content Weight', minimum=0., maximum=1., value=0.25, step=0.001)
+                     mixed_precision = gr.Radio(choices=['bf16', 'no'], value='bf16', label='Mixed Precision')
+
+                 base_model_list = ['stable-diffusion-v1-5']
+                 model = gr.Radio(choices=base_model_list, label='Select a Base Model', value='stable-diffusion-v1-5')
+
+             with gr.Column():
+                 gr.Markdown('#### Output Image:\n')
+                 result_gallery = gr.Gallery(label='Output', elem_id='gallery', columns=2, height='auto', preview=True)
+
+         gr.Examples(
+             [[Image.open('./webui/images/lecun.png').convert('RGB'), Image.open('./webui/images/40.jpg').convert('RGB'), 300, 0.23]],
+             [content_image, style_image, num_steps, content_weight]
+         )
+
+
+         ips = [content_image, style_image, seed, num_steps, lr, content_weight, mixed_precision, model]
+
+         run_button.click(fn=runner.run_style_transfer, inputs=ips, outputs=[result_gallery])
webui/tab_texture_synthesis.py ADDED
@@ -0,0 +1,46 @@
+ import os
+ from PIL import Image
+ import gradio as gr
+
+
+ def create_interface_texture_synthesis(runner):
+     with gr.Blocks():
+         with gr.Row():
+             gr.Markdown('1. Upload the texture image as input.\n'
+                         '2. (Optional) Customize the configurations below as needed.\n'
+                         '3. Click `Run` to start synthesis.')
+
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     texture_image = gr.Image(label='Input Texture Image', type='pil', interactive=True,
+                                              value=Image.open('examples/s1.jpg').convert('RGB') if os.path.exists('examples/s1.jpg') else None)
+
+                 run_button = gr.Button(value='Run')
+
+                 with gr.Accordion('Options', open=True):
+                     height = gr.Number(label='Height', value=512, precision=0, minimum=2, maximum=4096)
+                     width = gr.Number(label='Width', value=1024, precision=0, minimum=2, maximum=4096)
+                     seed = gr.Number(label='Seed', value=2025, precision=0, minimum=0, maximum=2**31)
+                     num_steps = gr.Slider(label='Number of Steps', minimum=1, maximum=1000, value=200, step=1)
+                     iterations = gr.Slider(label='Iterations', minimum=0, maximum=10, value=2, step=1)
+                     lr = gr.Slider(label='Learning Rate', minimum=0.01, maximum=0.5, value=0.05, step=0.01)
+                     mixed_precision = gr.Radio(choices=['bf16', 'no'], value='bf16', label='Mixed Precision')
+                     num_images_per_prompt = gr.Slider(label='Num Images Per Prompt', minimum=1, maximum=10, value=1, step=1)
+
+                 base_model_list = ['stable-diffusion-v1-5']
+                 model = gr.Radio(choices=base_model_list, label='Select a Base Model', value='stable-diffusion-v1-5')
+                 synthesis_way = gr.Radio(['Sampling', 'MultiDiffusion'], label='Synthesis Method', value='MultiDiffusion')
+
+             with gr.Column():
+                 gr.Markdown('#### Output Image:\n')
+                 result_gallery = gr.Gallery(label='Output', elem_id='gallery', columns=2, height='auto', preview=True)
+
+         gr.Examples(
+             [[Image.open('./webui/images/42.jpg').convert('RGB'), 'MultiDiffusion', 512, 1024]],
+             [texture_image, synthesis_way, height, width]
+         )
+         ips = [texture_image, height, width, seed, num_steps, iterations, lr, mixed_precision, num_images_per_prompt, synthesis_way, model]
+
+         run_button.click(fn=runner.run_texture_synthesis, inputs=ips, outputs=[result_gallery])
+
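Finally, a minimal sketch of the texture-synthesis path with the MultiDiffusion option, which routes to sd15.panorama with stride=8 and view_batch_size=8; values again mirror the tab defaults, and the file paths are assumptions rather than part of the commit.

from PIL import Image
from webui.runner import Runner

runner = Runner()
# 'Sampling' is the other accepted value for synthesis_way; anything else raises ValueError.
textures = runner.run_texture_synthesis(
    texture_image=Image.open('./webui/images/42.jpg').convert('RGB'),
    height=512,
    width=1024,
    seed=2025,
    num_steps=200,
    iterations=2,
    lr=0.05,
    mixed_precision='bf16',
    num_images_per_prompt=1,
    synthesis_way='MultiDiffusion',
    model='stable-diffusion-v1-5',
)
textures[0].save('texture.png')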