Init demo.
- app.py +39 -7
- losses.py +56 -0
- pipeline_sd.py +680 -0
- pipeline_sdxl.py +403 -0
- requirements.txt +10 -0
- train_vae.py +87 -0
- utils.py +195 -0
- webui/__init__.py +5 -0
- webui/__pycache__/__init__.cpython-310.pyc +0 -0
- webui/__pycache__/runner.cpython-310.pyc +0 -0
- webui/__pycache__/tab_style_t2i.cpython-310.pyc +0 -0
- webui/__pycache__/tab_style_transfer.cpython-310.pyc +0 -0
- webui/__pycache__/tab_texture_synthesis.cpython-310.pyc +0 -0
- webui/images/40.jpg +0 -0
- webui/images/42.jpg +0 -0
- webui/images/image_02_01.jpg +0 -0
- webui/images/lecun.png +0 -0
- webui/runner.py +157 -0
- webui/tab_style_t2i.py +51 -0
- webui/tab_style_transfer.py +45 -0
- webui/tab_texture_synthesis.py +46 -0
app.py
CHANGED
@@ -1,7 +1,39 @@
-import gradio as gr
-
-
-
-
-
-
+import gradio as gr
+from webui import (
+    create_interface_texture_synthesis,
+    create_interface_style_t2i,
+    create_interface_style_transfer,
+    Runner
+)
+
+
+import os
+os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
+
+
+def main():
+    runner = Runner()
+
+    with gr.Blocks(analytics_enabled=False,
+                   title='Attention Distillation',
+                   ) as demo:
+
+        md_txt = "# Attention Distillation" \
+            "\nOfficial demo of the paper [Attention Distillation: A Unified Approach to Visual Characteristics Transfer](https://arxiv.org/abs/2502.20235)"
+        gr.Markdown(md_txt)
+
+        with gr.Tabs(selected='tab_style_transfer'):
+            with gr.TabItem("Style Transfer", id='tab_style_transfer'):
+                create_interface_style_transfer(runner=runner)
+
+            with gr.TabItem("Style-Specific Text-to-Image Generation", id='tab_style_t2i'):
+                create_interface_style_t2i(runner=runner)
+
+            with gr.TabItem("Texture Synthesis", id='tab_texture_syn'):
+                create_interface_texture_synthesis(runner=runner)
+
+        demo.launch(share=False, debug=False)
+
+
+if __name__ == '__main__':
+    main()
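The webui/* modules that app.py imports are part of this commit, but their contents are not visible in this diff view. Purely as a hypothetical illustration of the wiring app.py expects, a tab factory such as create_interface_style_transfer(runner=...) might look roughly like the sketch below; the components and the runner.run_style_transfer method name are invented for the example and are not the Space's actual code.

import gradio as gr


def create_interface_style_transfer(runner):
    # Hypothetical tab layout: two input images, one output, a run button.
    with gr.Row():
        content = gr.Image(label="Content image", type="pil")
        style = gr.Image(label="Style image", type="pil")
        output = gr.Image(label="Result")
    run_btn = gr.Button("Run")
    # `runner.run_style_transfer` is a hypothetical method; the real Runner lives in webui/runner.py.
    run_btn.click(fn=runner.run_style_transfer, inputs=[content, style], outputs=[output])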
losses.py
ADDED
@@ -0,0 +1,56 @@
import math

import numpy as np
import torch
import torch.nn.functional as F

loss_fn = torch.nn.L1Loss()


def ad_loss(
    q_list, ks_list, vs_list, self_out_list, scale=1, source_mask=None, target_mask=None
):
    loss = 0
    attn_mask = None
    for q, ks, vs, self_out in zip(q_list, ks_list, vs_list, self_out_list):
        if source_mask is not None and target_mask is not None:
            w = h = int(np.sqrt(q.shape[2]))
            mask_1 = torch.flatten(F.interpolate(source_mask, size=(h, w)))
            mask_2 = torch.flatten(F.interpolate(target_mask, size=(h, w)))
            attn_mask = mask_1.unsqueeze(0) == mask_2.unsqueeze(1)
            attn_mask = attn_mask.to(q.device)

        target_out = F.scaled_dot_product_attention(
            q * scale,
            torch.cat(torch.chunk(ks, ks.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
            torch.cat(torch.chunk(vs, vs.shape[0]), 2).repeat(q.shape[0], 1, 1, 1),
            attn_mask=attn_mask
        )
        loss += loss_fn(self_out, target_out.detach())
    return loss


def q_loss(q_list, qc_list):
    loss = 0
    for q, qc in zip(q_list, qc_list):
        loss += loss_fn(q, qc.detach())
    return loss

# weight = 200
def qk_loss(q_list, k_list, qc_list, kc_list):
    loss = 0
    for q, k, qc, kc in zip(q_list, k_list, qc_list, kc_list):
        scale_factor = 1 / math.sqrt(q.size(-1))
        self_map = torch.softmax(q @ k.transpose(-2, -1) * scale_factor, dim=-1)
        target_map = torch.softmax(qc @ kc.transpose(-2, -1) * scale_factor, dim=-1)
        loss += loss_fn(self_map, target_map.detach())
    return loss

# weight = 1
def qkv_loss(q_list, k_list, vc_list, c_out_list):
    loss = 0
    for q, k, vc, target_out in zip(q_list, k_list, vc_list, c_out_list):
        self_out = F.scaled_dot_product_attention(q, k, vc)
        loss += loss_fn(self_out, target_out.detach())
    return loss
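These losses consume lists of attention features cached per self-attention layer, each tensor shaped (batch, heads, tokens, head_dim). A minimal shape-level sketch of feeding them with random tensors (all sizes below are illustrative, not taken from the repository):

import torch
import torch.nn.functional as F

from losses import ad_loss, q_loss

heads, tokens, head_dim = 8, 32 * 32, 40

# features cached from the latents being optimized (batch of 2)
q_list = [torch.randn(2, heads, tokens, head_dim)]
self_out_list = [F.scaled_dot_product_attention(q_list[0], q_list[0], q_list[0])]

# features cached from the style image (batch of 1)
ks_list = [torch.randn(1, heads, tokens, head_dim)]
vs_list = [torch.randn(1, heads, tokens, head_dim)]

style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=1.0)

# optional content-preservation term against the content image's queries
qc_list = [torch.randn(2, heads, tokens, head_dim)]
content_loss = q_loss(q_list, qc_list)

print(style_loss.item(), content_loss.item())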
pipeline_sd.py
ADDED
@@ -0,0 +1,680 @@
import copy
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import utils
from accelerate import Accelerator
from diffusers import StableDiffusionPipeline
from diffusers.image_processor import PipelineImageInput
from losses import *
from tqdm import tqdm


class ADPipeline(StableDiffusionPipeline):
    def freeze(self):
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.classifier.requires_grad_(False)

    @torch.no_grad()
    def image2latent(self, image):
        dtype = next(self.vae.parameters()).dtype
        device = self._execution_device
        image = image.to(device=device, dtype=dtype) * 2.0 - 1.0
        latent = self.vae.encode(image)["latent_dist"].mean
        latent = latent * self.vae.config.scaling_factor
        return latent

    @torch.no_grad()
    def latent2image(self, latent):
        dtype = next(self.vae.parameters()).dtype
        device = self._execution_device
        latent = latent.to(device=device, dtype=dtype)
        latent = latent / self.vae.config.scaling_factor
        image = self.vae.decode(latent)[0]
        return (image * 0.5 + 0.5).clamp(0, 1)

    def init(self, enable_gradient_checkpoint):
        self.freeze()
        weight_dtype = torch.float32
        if self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        # Move unet, vae and text_encoder to device and cast to weight_dtype
        self.unet.to(self.accelerator.device, dtype=weight_dtype)
        self.vae.to(self.accelerator.device, dtype=weight_dtype)
        self.text_encoder.to(self.accelerator.device, dtype=weight_dtype)
        self.classifier.to(self.accelerator.device, dtype=weight_dtype)
        self.classifier = self.accelerator.prepare(self.classifier)
        if enable_gradient_checkpoint:
            self.classifier.enable_gradient_checkpointing()

    def sample(
        self,
        lr=0.05,
        iters=1,
        attn_scale=1,
        adain=False,
        weight=0.25,
        controller=None,
        style_image=None,
        content_image=None,
        mixed_precision="no",
        start_time=999,
        enable_gradient_checkpoint=False,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        **kwargs,
    ):
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        self.accelerator = Accelerator(
            mixed_precision=mixed_precision, gradient_accumulation_steps=1
        )
        self.init(enable_gradient_checkpoint)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )
        do_cfg = guidance_scale > 1.0

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_cfg,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_cfg:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                do_cfg,
            )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
            else None
        )

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                batch_size * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps
        self.style_latent = self.image2latent(style_image)
        if content_image is not None:
            self.content_latent = self.image2latent(content_image)
        else:
            self.content_latent = None
        null_embeds = self.encode_prompt("", device, 1, False)[0]
        self.null_embeds = null_embeds
        self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0])
        self.null_embeds_for_style = torch.cat(
            [null_embeds] * self.style_latent.shape[0]
        )

        self.adain = adain
        self.attn_scale = attn_scale
        self.cache = utils.DataCache()
        self.controller = controller
        utils.register_attn_control(
            self.classifier, controller=self.controller, cache=self.cache
        )
        print("Total self attention layers of Unet: ", controller.num_self_layers)
        print("Self attention layers for AD: ", controller.self_layers)

        pbar = tqdm(timesteps, desc="Sample")
        for i, t in enumerate(pbar):
            with torch.no_grad():
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_cfg:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]
            if iters > 0 and t < start_time:
                latents = self.AD(latents, t, lr, iters, pbar, weight)

        images = self.latent2image(latents)
        # Offload all models
        self.maybe_free_model_hooks()
        return images

    def optimize(
        self,
        latents=None,
        attn_scale=1.0,
        lr=0.05,
        iters=1,
        weight=0,
        width=512,
        height=512,
        batch_size=1,
        controller=None,
        style_image=None,
        content_image=None,
        mixed_precision="no",
        num_inference_steps=50,
        enable_gradient_checkpoint=False,
        source_mask=None,
        target_mask=None,
    ):
        height = height // self.vae_scale_factor
        width = width // self.vae_scale_factor

        self.accelerator = Accelerator(
            mixed_precision=mixed_precision, gradient_accumulation_steps=1
        )
        self.init(enable_gradient_checkpoint)

        style_latent = self.image2latent(style_image)
        latents = torch.randn((batch_size, 4, height, width), device=self.device)
        null_embeds = self.encode_prompt("", self.device, 1, False)[0]
        null_embeds_for_latents = null_embeds.repeat(latents.shape[0], 1, 1)
        null_embeds_for_style = null_embeds.repeat(style_latent.shape[0], 1, 1)

        if content_image is not None:
            content_latent = self.image2latent(content_image)
            latents = torch.cat([content_latent.clone()] * batch_size)
            null_embeds_for_content = null_embeds.repeat(content_latent.shape[0], 1, 1)

        self.cache = utils.DataCache()
        self.controller = controller
        utils.register_attn_control(
            self.classifier, controller=self.controller, cache=self.cache
        )
        print("Total self attention layers of Unet: ", controller.num_self_layers)
        print("Self attention layers for AD: ", controller.self_layers)

        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps
        latents = latents.detach().float()
        optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
        optimizer = self.accelerator.prepare(optimizer)
        pbar = tqdm(timesteps, desc="Optimize")
        for i, t in enumerate(pbar):
            # t = torch.tensor([1], device=self.device)
            with torch.no_grad():
                qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
                    style_latent,
                    t,
                    null_embeds_for_style,
                )
                if content_image is not None:
                    qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
                        content_latent,
                        t,
                        null_embeds_for_content,
                    )
            for j in range(iters):
                style_loss = 0
                content_loss = 0
                optimizer.zero_grad()
                q_list, k_list, v_list, self_out_list = self.extract_feature(
                    latents,
                    t,
                    null_embeds_for_latents,
                )
                style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=attn_scale, source_mask=source_mask, target_mask=target_mask)
                if content_image is not None:
                    content_loss = q_loss(q_list, qc_list)
                    # content_loss = qk_loss(q_list, k_list, qc_list, kc_list)
                    # content_loss = qkv_loss(q_list, k_list, vc_list, c_out_list)
                loss = style_loss + content_loss * weight
                self.accelerator.backward(loss)
                optimizer.step()
                pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
        images = self.latent2image(latents)
        # Offload all models
        self.maybe_free_model_hooks()
        return images

    def panorama(
        self,
        lr=0.05,
        iters=1,
        attn_scale=1,
        adain=False,
        controller=None,
        style_image=None,
        mixed_precision="no",
        enable_gradient_checkpoint=False,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1,
        stride=8,
        view_batch_size: int = 16,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        **kwargs,
    ):

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        self.accelerator = Accelerator(
            mixed_precision=mixed_precision, gradient_accumulation_steps=1
        )
        self.init(enable_gradient_checkpoint)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_cfg = guidance_scale > 1.0

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None)
            if cross_attention_kwargs is not None
            else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_cfg,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_cfg:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Define panorama grid and initialize views for synthesis.
        # prepare batch grid
        views = self.get_views_(height, width, window_size=64, stride=stride)
        views_batch = [
            views[i : i + view_batch_size]
            for i in range(0, len(views), view_batch_size)
        ]
        print(len(views), len(views_batch), views_batch)
        self.scheduler.set_timesteps(num_inference_steps)
        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(
            views_batch
        )
        count = torch.zeros_like(latents)
        value = torch.zeros_like(latents)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )

        # 7.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                batch_size * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 8. Denoising loop
        # Each denoising step also includes refinement of the latents with respect to the
        # views.

        timesteps = self.scheduler.timesteps
        self.style_latent = self.image2latent(style_image)
        self.content_latent = None
        null_embeds = self.encode_prompt("", device, 1, False)[0]
        self.null_embeds = null_embeds
        self.null_embeds_for_latents = torch.cat([null_embeds] * latents.shape[0])
        self.null_embeds_for_style = torch.cat(
            [null_embeds] * self.style_latent.shape[0]
        )
        self.adain = adain
        self.attn_scale = attn_scale
        self.cache = utils.DataCache()
        self.controller = controller
        utils.register_attn_control(
            self.classifier, controller=self.controller, cache=self.cache
        )
        print("Total self attention layers of Unet: ", controller.num_self_layers)
        print("Self attention layers for AD: ", controller.self_layers)

        pbar = tqdm(timesteps, desc="Sample")
        for i, t in enumerate(pbar):
            count.zero_()
            value.zero_()
            # generate views
            # Here, we iterate through different spatial crops of the latents and denoise them. These
            # denoised (latent) crops are then averaged to produce the final latent
            # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
            # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
            # Batch views denoise
            for j, batch_view in enumerate(views_batch):
                vb_size = len(batch_view)
                # get the latents corresponding to the current view coordinates
                latents_for_view = torch.cat(
                    [
                        latents[:, :, h_start:h_end, w_start:w_end]
                        for h_start, h_end, w_start, w_end in batch_view
                    ]
                )
                # rematch block's scheduler status
                self.scheduler.__dict__.update(views_scheduler_status[j])

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    latents_for_view.repeat_interleave(2, dim=0)
                    if do_cfg
                    else latents_for_view
                )

                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # repeat prompt_embeds for batch
                prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)

                # predict the noise residual
                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds_input,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                    ).sample

                # perform guidance
                if do_cfg:
                    noise_pred_uncond, noise_pred_text = (
                        noise_pred[::2],
                        noise_pred[1::2],
                    )
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents_denoised_batch = self.scheduler.step(
                    noise_pred, t, latents_for_view, **extra_step_kwargs
                ).prev_sample
                if iters > 0:
                    self.null_embeds_for_latents = torch.cat(
                        [self.null_embeds] * noise_pred.shape[0]
                    )
                    latents_denoised_batch = self.AD(
                        latents_denoised_batch, t, lr, iters, pbar
                    )
                # save views scheduler status after sample
                views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)

                # extract value from batch
                for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
                    latents_denoised_batch.chunk(vb_size), batch_view
                ):

                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            latents = torch.where(count > 0, value / count, value)

        images = self.latent2image(latents)
        # Offload all models
        self.maybe_free_model_hooks()
        return images

    def AD(self, latents, t, lr, iters, pbar, weight=0):
        t = max(
            t
            - self.scheduler.config.num_train_timesteps
            // self.scheduler.num_inference_steps,
            torch.tensor([0], device=self.device),
        )
        if self.adain:
            noise = torch.randn_like(self.style_latent)
            style_latent = self.scheduler.add_noise(self.style_latent, noise, t)
            latents = utils.adain(latents, style_latent)

        with torch.no_grad():
            qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
                self.style_latent,
                t,
                self.null_embeds_for_style,
                add_noise=True,
            )
            if self.content_latent is not None:
                qc_list, kc_list, vc_list, c_out_list = self.extract_feature(
                    self.content_latent,
                    t,
                    self.null_embeds,
                    add_noise=True,
                )

        latents = latents.detach()
        optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
        optimizer = self.accelerator.prepare(optimizer)

        for j in range(iters):
            style_loss = 0
            content_loss = 0
            optimizer.zero_grad()
            q_list, k_list, v_list, self_out_list = self.extract_feature(
                latents,
                t,
                self.null_embeds_for_latents,
                add_noise=False,
            )
            style_loss = ad_loss(q_list, ks_list, vs_list, self_out_list, scale=self.attn_scale)
            if self.content_latent is not None:
                content_loss = q_loss(q_list, qc_list)
                # content_loss = qk_loss(q_list, k_list, qc_list, kc_list)
                # content_loss = qkv_loss(q_list, k_list, vc_list, c_out_list)
            loss = style_loss + content_loss * weight
            self.accelerator.backward(loss)
            optimizer.step()

            pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
        latents = latents.detach()
        return latents

    def extract_feature(
        self,
        latent,
        t,
        embeds,
        add_noise=False,
    ):
        self.cache.clear()
        self.controller.step()
        if add_noise:
            noise = torch.randn_like(latent)
            latent_ = self.scheduler.add_noise(latent, noise, t)
        else:
            latent_ = latent
        _ = self.classifier(latent_, t, embeds)[0]
        return self.cache.get()

    def get_views_(
        self,
        panorama_height: int,
        panorama_width: int,
        window_size: int = 64,
        stride: int = 8,
    ) -> List[Tuple[int, int, int, int]]:
        panorama_height //= 8
        panorama_width //= 8

        num_blocks_height = (
            math.ceil((panorama_height - window_size) / stride) + 1
            if panorama_height > window_size
            else 1
        )
        num_blocks_width = (
            math.ceil((panorama_width - window_size) / stride) + 1
            if panorama_width > window_size
            else 1
        )

        views = []
        for i in range(int(num_blocks_height)):
            for j in range(int(num_blocks_width)):
                h_start = int(min(i * stride, panorama_height - window_size))
                w_start = int(min(j * stride, panorama_width - window_size))

                h_end = h_start + window_size
                w_end = w_start + window_size

                views.append((h_start, h_end, w_start, w_end))

        return views
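pipeline_sd.py expects two pieces that are wired up elsewhere in the Space: a `classifier` UNet used only for attention-feature extraction, and a `controller` object consumed by utils.register_attn_control. A rough usage sketch follows; it is not the Space's runner code, the hyperparameters are illustrative, and SimpleController is a hypothetical stand-in that only carries the attributes the pipeline and utils actually touch.

import copy

from pipeline_sd import ADPipeline
from utils import load_image, save_image


class SimpleController:
    # Hypothetical stand-in for the demo's attention controller (the real one is not in this diff).
    def __init__(self, self_layers=range(10, 16)):  # which self-attention layers to distill (assumed)
        self.num_self_layers = 0   # filled in by utils.register_attn_control
        self.self_layers = list(self_layers)
        self.cur_self_layer = 0

    def step(self):
        # reset the per-pass layer counter before each feature extraction
        self.cur_self_layer = 0


pipe = ADPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
# assumption: a frozen copy of the UNet serves as the "classifier" feature extractor
pipe.classifier = copy.deepcopy(pipe.unet)

style = load_image("webui/images/42.jpg", size=(512, 512))
content = load_image("webui/images/40.jpg", size=(512, 512))

images = pipe.optimize(
    lr=0.05,                      # illustrative settings
    iters=2,
    weight=0.25,
    batch_size=1,
    controller=SimpleController(),
    style_image=style,
    content_image=content,
    mixed_precision="no",
    num_inference_steps=200,
)
save_image(images, "style_transfer.png")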
pipeline_sdxl.py
ADDED
@@ -0,0 +1,403 @@
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import utils
from accelerate import Accelerator
from accelerate.utils import (
    DistributedDataParallelKwargs,
    ProjectConfiguration,
    set_seed,
)
from diffusers import StableDiffusionXLPipeline
from diffusers.image_processor import PipelineImageInput
from diffusers.utils.torch_utils import is_compiled_module
from losses import *

# from peft import LoraConfig, set_peft_model_state_dict
from tqdm import tqdm


class ADPipeline(StableDiffusionXLPipeline):
    def freeze(self):
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.text_encoder_2.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.classifier.requires_grad_(False)

    @torch.no_grad()
    def image2latent(self, image):
        dtype = next(self.vae.parameters()).dtype
        device = self._execution_device
        image = image.to(device=device, dtype=dtype) * 2.0 - 1.0
        latent = self.vae.encode(image)["latent_dist"].mean
        latent = latent * self.vae.config.scaling_factor
        return latent

    @torch.no_grad()
    def latent2image(self, latent):
        dtype = next(self.vae.parameters()).dtype
        device = self._execution_device
        latent = latent.to(device=device, dtype=dtype)
        latent = latent / self.vae.config.scaling_factor
        image = self.vae.decode(latent)[0]
        return (image * 0.5 + 0.5).clamp(0, 1)

    def init(self, enable_gradient_checkpoint):
        self.freeze()
        weight_dtype = torch.float32
        if self.accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        # Move unet, vae and text_encoder to device and cast to weight_dtype
        self.unet.to(self.accelerator.device, dtype=weight_dtype)
        self.vae.to(self.accelerator.device, dtype=weight_dtype)
        self.text_encoder.to(self.accelerator.device, dtype=weight_dtype)
        self.text_encoder_2.to(self.accelerator.device, dtype=weight_dtype)
        self.classifier.to(self.accelerator.device, dtype=weight_dtype)
        self.classifier = self.accelerator.prepare(self.classifier)
        if enable_gradient_checkpoint:
            self.classifier.enable_gradient_checkpointing()
        # self.classifier.train()


    def sample(
        self,
        lr=0.05,
        iters=1,
        adain=True,
        controller=None,
        style_image=None,
        mixed_precision="no",
        init_from_style=False,
        start_time=999,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        enable_gradient_checkpoint=False,
        **kwargs,
    ):
        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._interrupt = False

        self.accelerator = Accelerator(
            mixed_precision=mixed_precision, gradient_accumulation_steps=1
        )
        self.init(enable_gradient_checkpoint)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        null_add_time_ids = add_time_ids.to(device)
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0
            )
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(
            batch_size * num_images_per_prompt, 1
        )

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )
        # 8.1 Apply denoising_end
        if (
            self.denoising_end is not None
            and isinstance(self.denoising_end, float)
            and self.denoising_end > 0
            and self.denoising_end < 1
        ):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(
                list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
            )
            timesteps = timesteps[:num_inference_steps]

        # 9. Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                batch_size * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)
        self.timestep_cond = timestep_cond
        (null_embeds, _, null_pooled_embeds, _) = self.encode_prompt("", device=device)

        added_cond_kwargs = {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids
        }
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            added_cond_kwargs["image_embeds"] = image_embeds

        self.scheduler.set_timesteps(num_inference_steps)

        timesteps = self.scheduler.timesteps
        style_latent = self.image2latent(style_image)
        if init_from_style:
            latents = torch.cat([style_latent] * latents.shape[0])
            noise = torch.randn_like(latents)
            latents = self.scheduler.add_noise(
                latents,
                noise,
                torch.tensor([999]),
            )

        self.style_latent = style_latent
        self.null_embeds_for_latents = torch.cat([null_embeds] * (latents.shape[0]))
        self.null_embeds_for_style = torch.cat([null_embeds] * style_latent.shape[0])
        self.null_added_cond_kwargs_for_latents = {
            "text_embeds": torch.cat([null_pooled_embeds] * (latents.shape[0])),
            "time_ids": torch.cat([null_add_time_ids] * (latents.shape[0])),
        }
        self.null_added_cond_kwargs_for_style = {
            "text_embeds": torch.cat([null_pooled_embeds] * style_latent.shape[0]),
            "time_ids": torch.cat([null_add_time_ids] * style_latent.shape[0]),
        }
        self.adain = adain
        self.cache = utils.DataCache()
        self.controller = controller
        utils.register_attn_control(
            self.classifier, controller=controller, cache=self.cache
        )
        print("Total self attention layers of Unet: ", controller.num_self_layers)
        print("Self attention layers for AD: ", controller.self_layers)

        pbar = tqdm(timesteps, desc="Sample")
        for i, t in enumerate(pbar):
            with torch.no_grad():
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if iters > 0 and t < start_time:
                latents = self.AD(latents, t, lr, iters, pbar)


        # Offload all models
        # self.enable_model_cpu_offload()
        images = self.latent2image(latents)
        self.maybe_free_model_hooks()
        return images

    def AD(self, latents, t, lr, iters, pbar):
        t = max(
            t
            - self.scheduler.config.num_train_timesteps
            // self.scheduler.num_inference_steps,
            torch.tensor([0], device=self.device),
        )

        if self.adain:
            noise = torch.randn_like(self.style_latent)
            style_latent = self.scheduler.add_noise(self.style_latent, noise, t)
            latents = utils.adain(latents, style_latent)

        with torch.no_grad():
            qs_list, ks_list, vs_list, s_out_list = self.extract_feature(
                self.style_latent,
                t,
                self.null_embeds_for_style,
                self.timestep_cond,
                self.null_added_cond_kwargs_for_style,
                add_noise=True,
            )
        # latents = latents.to(dtype=torch.float32)
        latents = latents.detach()
        optimizer = torch.optim.Adam([latents.requires_grad_()], lr=lr)
        optimizer, latents = self.accelerator.prepare(optimizer, latents)

        for j in range(iters):
            optimizer.zero_grad()
            q_list, k_list, v_list, self_out_list = self.extract_feature(
                latents,
                t,
                self.null_embeds_for_latents,
                self.timestep_cond,
                self.null_added_cond_kwargs_for_latents,
                add_noise=False,
            )

            loss = ad_loss(q_list, ks_list, vs_list, self_out_list)
            self.accelerator.backward(loss)
            optimizer.step()

            pbar.set_postfix(loss=loss.item(), time=t.item(), iter=j)
        latents = latents.detach()
        return latents

    def extract_feature(
        self,
        latent,
        t,
        encoder_hidden_states,
        timestep_cond,
        added_cond_kwargs,
        add_noise=False,
    ):
        self.cache.clear()
        self.controller.step()
        if add_noise:
            noise = torch.randn_like(latent)
            latent_ = self.scheduler.add_noise(latent, noise, t)
        else:
            latent_ = latent
        self.classifier(
            latent_,
            t,
            encoder_hidden_states=encoder_hidden_states,
            timestep_cond=timestep_cond,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
        return self.cache.get()
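The SDXL variant follows the same pattern but additionally freezes text_encoder_2, passes SDXL's added_cond_kwargs (pooled text embeddings and time ids) through to the feature extractor, and exposes init_from_style and adain options in sample(). A brief, hedged sketch of style-specific text-to-image generation, reusing the hypothetical SimpleController stand-in from the previous sketch, with an example SDXL checkpoint and illustrative settings:

import copy

from pipeline_sdxl import ADPipeline
from utils import load_image, save_image

pipe = ADPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.classifier = copy.deepcopy(pipe.unet)  # assumed feature-extractor copy, as above

images = pipe.sample(
    prompt="a cat",                                                        # illustrative prompt
    style_image=load_image("webui/images/image_02_01.jpg", size=(1024, 1024)),
    controller=SimpleController(),   # stand-in class defined in the previous sketch
    iters=2,
    lr=0.05,
    adain=True,
    guidance_scale=5.0,
    num_inference_steps=50,
)
save_image(images, "style_t2i.png")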
requirements.txt
ADDED
@@ -0,0 +1,10 @@
diffusers
torch>=2.0.0
torchvision
transformers
accelerate
safetensors
spaces
huggingface-hub
gradio
matplotlib
train_vae.py
ADDED
@@ -0,0 +1,87 @@
import argparse
import os

import torch
from diffusers import AutoencoderKL
from torch import nn
from torch.optim import Adam
from utils import load_image, save_image


def main(args):
    os.makedirs(args.out_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vae = AutoencoderKL.from_pretrained(args.vae_model_path).to(
        device, dtype=torch.float32
    )
    vae.requires_grad_(False)

    image = load_image(args.image_path, size=(512, 512)).to(device, dtype=torch.float32)
    image = image * 2 - 1
    save_image(image / 2 + 0.5, f"{args.out_dir}/ori_image.png")

    latents = vae.encode(image)["latent_dist"].mean
    save_image(latents, f"{args.out_dir}/latents.png")

    rec_image = vae.decode(latents, return_dict=False)[0]
    save_image(rec_image / 2 + 0.5, f"{args.out_dir}/rec_image.png")

    for param in vae.decoder.parameters():
        param.requires_grad = True

    loss_fn = nn.L1Loss()
    optimizer = Adam(vae.decoder.parameters(), lr=args.learning_rate)

    # Training loop
    for epoch in range(args.num_epochs):
        reconstructed = vae.decode(latents, return_dict=False)[0]
        loss = loss_fn(reconstructed, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}/{args.num_epochs}, Loss: {loss.item()}")

    rec_image = vae.decode(latents, return_dict=False)[0]
    save_image(rec_image / 2 + 0.5, f"{args.out_dir}/trained_rec_image.png")
    vae.save_pretrained(
        f"{args.out_dir}/trained_vae_{os.path.basename(args.image_path)}"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a VAE with given image and settings."
    )

    # Add arguments
    parser.add_argument(
        "--out_dir",
        type=str,
        default="./trained_vae/",
        help="Output directory to save results",
    )
    parser.add_argument(
        "--vae_model_path",
        type=str,
        required=True,
        help="Path to the pretrained VAE model",
    )
    parser.add_argument(
        "--image_path", type=str, required=True, help="Path to the input image"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate for the optimizer",
    )
    parser.add_argument(
        "--num_epochs", type=int, default=75, help="Number of training epochs"
    )

    args = parser.parse_args()
    main(args)
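train_vae.py fine-tunes only the VAE decoder against a single image, so that decoding the fixed latent reproduces that image more faithfully. Besides the CLI, it can be driven programmatically; a small sketch is shown below, where the checkpoint id and image path are only examples, not values taken from the Space.

from argparse import Namespace

import train_vae

train_vae.main(Namespace(
    out_dir="./trained_vae/",
    vae_model_path="stabilityai/sd-vae-ft-mse",  # any diffusers AutoencoderKL checkpoint (example)
    image_path="webui/images/40.jpg",            # example image shipped with the Space
    learning_rate=1e-4,
    num_epochs=75,
))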
utils.py
ADDED
@@ -0,0 +1,195 @@
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import math


def register_attn_control(unet, controller, cache=None):
    def attn_forward(self):
        def forward(
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
            *args,
            **kwargs,
        ):
            residual = hidden_states
            if self.spatial_norm is not None:
                hidden_states = self.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim

            if input_ndim == 4:
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, height * width
                ).transpose(1, 2)

            batch_size, sequence_length, _ = (
                hidden_states.shape
                if encoder_hidden_states is None
                else encoder_hidden_states.shape
            )

            if attention_mask is not None:
                attention_mask = self.prepare_attention_mask(
                    attention_mask, sequence_length, batch_size
                )
                # scaled_dot_product_attention expects attention_mask shape to be
                # (batch, heads, source_length, target_length)
                attention_mask = attention_mask.view(
                    batch_size, self.heads, -1, attention_mask.shape[-1]
                )

            if self.group_norm is not None:
                hidden_states = self.group_norm(
                    hidden_states.transpose(1, 2)
                ).transpose(1, 2)

            q = self.to_q(hidden_states)
            is_self = encoder_hidden_states is None

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif self.norm_cross:
                encoder_hidden_states = self.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            inner_dim = k.shape[-1]
            head_dim = inner_dim // self.heads

            q = q.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
            if is_self and controller.cur_self_layer in controller.self_layers:
                cache.add(q, k, v, hidden_states)

            hidden_states = hidden_states.transpose(1, 2).reshape(
                batch_size, -1, self.heads * head_dim
            )
            hidden_states = hidden_states.to(q.dtype)

            # linear proj
            hidden_states = self.to_out[0](hidden_states)
            # dropout
            hidden_states = self.to_out[1](hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(
                    batch_size, channel, height, width
                )
            if self.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / self.rescale_output_factor

            if is_self:
                controller.cur_self_layer += 1

            return hidden_states

        return forward

    def modify_forward(net, count):
        for name, subnet in net.named_children():
            if net.__class__.__name__ == "Attention":  # spatial Transformer layer
                net.forward = attn_forward(net)
                return count + 1
            elif hasattr(net, "children"):
                count = modify_forward(subnet, count)
        return count

    cross_att_count = 0
    for net_name, net in unet.named_children():
        cross_att_count += modify_forward(net, 0)
    controller.num_self_layers = cross_att_count // 2


def load_image(image_path, size=None, mode="RGB"):
    img = Image.open(image_path).convert(mode)
    if size is None:
        width, height = img.size
        new_width = (width // 64) * 64
        new_height = (height // 64) * 64
        size = (new_width, new_height)
    img = img.resize(size, Image.BICUBIC)
    return ToTensor()(img).unsqueeze(0)


def adain(source, target, eps=1e-6):
    source_mean, source_std = torch.mean(source, dim=(2, 3), keepdim=True), torch.std(
        source, dim=(2, 3), keepdim=True
    )
    target_mean, target_std = torch.mean(
        target, dim=(0, 2, 3), keepdim=True
    ), torch.std(target, dim=(0, 2, 3), keepdim=True)
    normalized_source = (source - source_mean) / (source_std + eps)
    transferred_source = normalized_source * target_std + target_mean

    return transferred_source


class Controller:
    def step(self):
        self.cur_self_layer = 0

    def __init__(self, self_layers=(0, 16)):
        self.num_self_layers = -1
        self.cur_self_layer = 0
        self.self_layers = list(range(*self_layers))


class DataCache:
    def __init__(self):
        self.q = []
        self.k = []
        self.v = []
        self.out = []

    def clear(self):
        self.q.clear()
        self.k.clear()
        self.v.clear()
        self.out.clear()

    def add(self, q, k, v, out):
        self.q.append(q)
        self.k.append(k)
        self.v.append(v)
        self.out.append(out)

    def get(self):
        return self.q.copy(), self.k.copy(), self.v.copy(), self.out.copy()


def show_image(path, title, display_height=3, title_fontsize=12):
    img = Image.open(path)
    img_width, img_height = img.size

    aspect_ratio = img_width / img_height
    display_width = display_height * aspect_ratio

    plt.figure(figsize=(display_width, display_height))
    plt.imshow(img)
    plt.title(title,
              fontsize=title_fontsize,
              fontweight='bold',
              pad=20)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
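These helpers are the attention-hooking machinery used by the pipelines: register_attn_control replaces the forward of every Attention module under a UNet, and DataCache collects the query/key/value/output tensors of the self-attention layers selected by Controller. A minimal sketch of how they compose is shown below; it is not part of the repo, the local checkpoint path is an assumption (mirroring how the web UI resolves model names), and the actual UNet forward pass is elided.

from diffusers import StableDiffusionPipeline

from utils import Controller, DataCache, register_attn_control

# Assumed local checkpoint layout; any SD 1.5-style pipeline with a UNet works the same way.
pipe = StableDiffusionPipeline.from_pretrained("./checkpoints/stable-diffusion-v1-5", safety_checker=None)

cache = DataCache()
controller = Controller(self_layers=(10, 16))     # capture only the later self-attention layers
register_attn_control(pipe.unet, controller, cache=cache)

controller.step()   # reset the per-forward layer counter
cache.clear()
# ... run a UNet forward pass here (this happens inside the AD pipelines' optimize/sample loops) ...
q_list, k_list, v_list, out_list = cache.get()    # self-attention features captured by the hooks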
webui/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .tab_style_t2i import create_interface_style_t2i
from .tab_style_transfer import create_interface_style_transfer
from .tab_texture_synthesis import create_interface_texture_synthesis
from .runner import Runner
webui/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (420 Bytes).

webui/__pycache__/runner.cpython-310.pyc
ADDED
Binary file (4.38 kB).

webui/__pycache__/tab_style_t2i.cpython-310.pyc
ADDED
Binary file (2.5 kB).

webui/__pycache__/tab_style_transfer.cpython-310.pyc
ADDED
Binary file (2.21 kB).

webui/__pycache__/tab_texture_synthesis.cpython-310.pyc
ADDED
Binary file (2.25 kB).

webui/images/40.jpg
ADDED

webui/images/42.jpg
ADDED

webui/images/image_02_01.jpg
ADDED

webui/images/lecun.png
ADDED
webui/runner.py
ADDED
@@ -0,0 +1,157 @@
import torch
from PIL import Image
from diffusers import DDIMScheduler
from accelerate.utils import set_seed
from torchvision.transforms.functional import to_pil_image, to_tensor

from pipeline_sd import ADPipeline
from pipeline_sdxl import ADPipeline as ADXLPipeline
from utils import Controller

import os


class Runner:
    def __init__(self):
        self.sd15 = None
        self.sdxl = None
        self.loss_fn = torch.nn.L1Loss(reduction="mean")

    def load_pipeline(self, model_path_or_name):
        if 'xl' in model_path_or_name and self.sdxl is None:
            scheduler = DDIMScheduler.from_pretrained(os.path.join('./checkpoints', model_path_or_name), subfolder="scheduler")
            self.sdxl = ADXLPipeline.from_pretrained(os.path.join('./checkpoints', model_path_or_name), scheduler=scheduler, safety_checker=None)
            self.sdxl.classifier = self.sdxl.unet
        elif self.sd15 is None:
            scheduler = DDIMScheduler.from_pretrained(os.path.join('./checkpoints', model_path_or_name), subfolder="scheduler")
            self.sd15 = ADPipeline.from_pretrained(os.path.join('./checkpoints', model_path_or_name), scheduler=scheduler, safety_checker=None)
            self.sd15.classifier = self.sd15.unet

    def preprocecss(self, image: Image.Image, height=None, width=None):
        if width is None or height is None:
            width, height = image.size
        new_width = (width // 64) * 64
        new_height = (height // 64) * 64
        size = (new_width, new_height)
        image = image.resize(size, Image.BICUBIC)
        return to_tensor(image).unsqueeze(0)

    def run_style_transfer(self, content_image, style_image, seed, num_steps, lr, content_weight, mixed_precision, model, **kwargs):
        self.load_pipeline(model)

        content_image = self.preprocecss(content_image)
        style_image = self.preprocecss(style_image, height=512, width=512)

        height, width = content_image.shape[-2:]
        set_seed(seed)
        controller = Controller(self_layers=(10, 16))
        result = self.sd15.optimize(
            lr=lr,
            batch_size=1,
            iters=1,
            width=width,
            height=height,
            weight=content_weight,
            controller=controller,
            style_image=style_image,
            content_image=content_image,
            mixed_precision=mixed_precision,
            num_inference_steps=num_steps,
            enable_gradient_checkpoint=False,
        )
        output_image = to_pil_image(result[0])
        del result
        torch.cuda.empty_cache()
        return [output_image]

    def run_style_t2i_generation(self, style_image, prompt, negative_prompt, guidance_scale, height, width, seed, num_steps, iterations, lr, num_images_per_prompt, mixed_precision, is_adain, model):
        self.load_pipeline(model)

        use_xl = 'xl' in model
        height, width = (1024, 1024) if 'xl' in model else (512, 512)
        style_image = self.preprocecss(style_image, height=height, width=width)

        set_seed(seed)
        self_layers = (64, 70) if use_xl else (10, 16)

        controller = Controller(self_layers=self_layers)

        pipeline = self.sdxl if use_xl else self.sd15
        images = pipeline.sample(
            controller=controller,
            iters=iterations,
            lr=lr,
            adain=is_adain,
            height=height,
            width=width,
            mixed_precision=mixed_precision,
            style_image=style_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            enable_gradient_checkpoint=False
        )
        output_images = [to_pil_image(image) for image in images]

        del images
        torch.cuda.empty_cache()
        return output_images

    def run_texture_synthesis(self, texture_image, height, width, seed, num_steps, iterations, lr, mixed_precision, num_images_per_prompt, synthesis_way, model):
        self.load_pipeline(model)

        texture_image = self.preprocecss(texture_image, height=512, width=512)

        set_seed(seed)
        controller = Controller(self_layers=(10, 16))

        if synthesis_way == 'Sampling':
            results = self.sd15.sample(
                lr=lr,
                adain=False,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        elif synthesis_way == 'MultiDiffusion':
            results = self.sd15.panorama(
                lr=lr,
                iters=iterations,
                width=width,
                height=height,
                weight=0.,
                controller=controller,
                style_image=texture_image,
                content_image=None,
                prompt="",
                negative_prompt="",
                stride=8,
                view_batch_size=8,
                mixed_precision=mixed_precision,
                num_inference_steps=num_steps,
                guidance_scale=1.,
                num_images_per_prompt=num_images_per_prompt,
                enable_gradient_checkpoint=False,
            )
        else:
            raise ValueError

        output_images = [to_pil_image(image) for image in results]
        del results
        torch.cuda.empty_cache()
        return output_images
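For completeness, here is a minimal sketch of driving Runner directly from Python, without the Gradio tabs. It is not part of the repo; it assumes the base model has already been placed under ./checkpoints/stable-diffusion-v1-5 (which is how load_pipeline resolves model names) and reuses the example images added in this commit, with parameter values mirroring the web UI defaults.

from PIL import Image

from webui import Runner

runner = Runner()
stylized = runner.run_style_transfer(
    content_image=Image.open("webui/images/lecun.png").convert("RGB"),
    style_image=Image.open("webui/images/40.jpg").convert("RGB"),
    seed=2025,
    num_steps=200,
    lr=0.05,
    content_weight=0.25,
    mixed_precision="bf16",
    model="stable-diffusion-v1-5",
)
stylized[0].save("stylized.png")   # run_style_transfer returns a list of PIL images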
webui/tab_style_t2i.py
ADDED
@@ -0,0 +1,51 @@
import os
from PIL import Image
import gradio as gr


def create_interface_style_t2i(runner):
    with gr.Blocks():
        with gr.Row():
            gr.Markdown('1. Upload the style image and type your prompt.\n'
                        '2. Choose the generative model.\n'
                        '3. (Optional) Customize the configurations below as needed.\n'
                        '4. Click `Run` to start generation.')

        with gr.Row():
            with gr.Column():
                style_image = gr.Image(label='Input Style Image', type='pil', interactive=True,
                                       value=Image.open('examples/s1.jpg').convert('RGB') if os.path.exists('examples/s1.jpg') else None)
                prompt = gr.Textbox(label='Prompt', value='A rocket')
                negative_prompt = gr.Textbox(label='Negative Prompt', value='')

                base_model_list = ['stable-diffusion-v1-5', 'stable-diffusion-xl-base-1.0']
                model = gr.Radio(choices=base_model_list, label='Select a Base Model', value='stable-diffusion-xl-base-1.0')

                run_button = gr.Button(value='Run')

                gr.Examples(
                    [[Image.open('./webui/images/image_02_01.jpg').convert('RGB'), 'A rocket', 'stable-diffusion-xl-base-1.0']],
                    [style_image, prompt, model]
                )

            with gr.Column():
                with gr.Accordion('Options', open=True):
                    guidance_scale = gr.Slider(label='Guidance Scale', minimum=1., maximum=30., value=7.5, step=0.1)
                    height = gr.Number(label='Height', value=1024, precision=0, minimum=2, maximum=4096)
                    width = gr.Number(label='Width', value=1024, precision=0, minimum=2, maximum=4096)
                    seed = gr.Number(label='Seed', value=2025, precision=0, minimum=0, maximum=2**31)
                    num_steps = gr.Slider(label='Number of Steps', minimum=1, maximum=1000, value=50, step=1)
                    iterations = gr.Slider(label='Iterations', minimum=0, maximum=10, value=2, step=1)
                    lr = gr.Slider(label='Learning Rate', minimum=0.01, maximum=0.5, value=0.015, step=0.001)
                    num_images_per_prompt = gr.Slider(label='Num Images Per Prompt', minimum=1, maximum=10, value=1, step=1)
                    mixed_precision = gr.Radio(choices=['bf16', 'no'], value='bf16', label='Mixed Precision')
                    is_adain = gr.Checkbox(label='Adain', value=True)

            with gr.Column():
                gr.Markdown('#### Output Image:\n')
                result_gallery = gr.Gallery(label='Output', elem_id='gallery', columns=2, height='auto', preview=True)

        ips = [style_image, prompt, negative_prompt, guidance_scale, height, width, seed, num_steps, iterations, lr, num_images_per_prompt, mixed_precision, is_adain, model]

        run_button.click(fn=runner.run_style_t2i_generation, inputs=ips, outputs=[result_gallery])
webui/tab_style_transfer.py
ADDED
@@ -0,0 +1,45 @@
import os
from PIL import Image
import gradio as gr


def create_interface_style_transfer(runner):
    with gr.Blocks():
        with gr.Row():
            gr.Markdown('1. Upload the content and style images as inputs.\n'
                        '2. (Optional) Customize the configurations below as needed.\n'
                        '3. Click `Run` to start transfer.')

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    content_image = gr.Image(label='Input Content Image', type='pil', interactive=True,
                                             value=Image.open('examples/c1.jpg').convert('RGB') if os.path.exists('examples/c1.jpg') else None)
                    style_image = gr.Image(label='Input Style Image', type='pil', interactive=True,
                                           value=Image.open('examples/s1.jpg').convert('RGB') if os.path.exists('examples/s1.jpg') else None)

                run_button = gr.Button(value='Run')

                with gr.Accordion('Options', open=True):
                    seed = gr.Number(label='Seed', value=2025, precision=0, minimum=0, maximum=2**31)
                    num_steps = gr.Slider(label='Number of Steps', minimum=1, maximum=1000, value=200, step=1)
                    lr = gr.Slider(label='Learning Rate', minimum=0.01, maximum=0.5, value=0.05, step=0.01)
                    content_weight = gr.Slider(label='Content Weight', minimum=0., maximum=1., value=0.25, step=0.001)
                    mixed_precision = gr.Radio(choices=['bf16', 'no'], value='bf16', label='Mixed Precision')

                    base_model_list = ['stable-diffusion-v1-5']
                    model = gr.Radio(choices=base_model_list, label='Select a Base Model', value='stable-diffusion-v1-5')

            with gr.Column():
                gr.Markdown('#### Output Image:\n')
                result_gallery = gr.Gallery(label='Output', elem_id='gallery', columns=2, height='auto', preview=True)

        gr.Examples(
            [[Image.open('./webui/images/lecun.png').convert('RGB'), Image.open('./webui/images/40.jpg').convert('RGB'), 300, 0.23]],
            [content_image, style_image, num_steps, content_weight]
        )

        ips = [content_image, style_image, seed, num_steps, lr, content_weight, mixed_precision, model]

        run_button.click(fn=runner.run_style_transfer, inputs=ips, outputs=[result_gallery])
webui/tab_texture_synthesis.py
ADDED
@@ -0,0 +1,46 @@
import os
from PIL import Image
import gradio as gr


def create_interface_texture_synthesis(runner):
    with gr.Blocks():
        with gr.Row():
            gr.Markdown('1. Upload the texture image as input.\n'
                        '2. (Optional) Customize the configurations below as needed.\n'
                        '3. Click `Run` to start synthesis.')

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    texture_image = gr.Image(label='Input Texture Image', type='pil', interactive=True,
                                             value=Image.open('examples/s1.jpg').convert('RGB') if os.path.exists('examples/s1.jpg') else None)

                run_button = gr.Button(value='Run')

                with gr.Accordion('Options', open=True):
                    height = gr.Number(label='Height', value=512, precision=0, minimum=2, maximum=4096)
                    width = gr.Number(label='Width', value=1024, precision=0, minimum=2, maximum=4096)
                    seed = gr.Number(label='Seed', value=2025, precision=0, minimum=0, maximum=2**31)
                    num_steps = gr.Slider(label='Number of Steps', minimum=1, maximum=1000, value=200, step=1)
                    iterations = gr.Slider(label='Iterations', minimum=0, maximum=10, value=2, step=1)
                    lr = gr.Slider(label='Learning Rate', minimum=0.01, maximum=0.5, value=0.05, step=0.01)
                    mixed_precision = gr.Radio(choices=['bf16', 'no'], value='bf16', label='Mixed Precision')
                    num_images_per_prompt = gr.Slider(label='Num Images Per Prompt', minimum=1, maximum=10, value=1, step=1)

                    base_model_list = ['stable-diffusion-v1-5']
                    model = gr.Radio(choices=base_model_list, label='Select a Base Model', value='stable-diffusion-v1-5')
                    synthesis_way = gr.Radio(['Sampling', 'MultiDiffusion'], label='Synthesis Way', value='MultiDiffusion')

            with gr.Column():
                gr.Markdown('#### Output Image:\n')
                result_gallery = gr.Gallery(label='Output', elem_id='gallery', columns=2, height='auto', preview=True)

        gr.Examples(
            [[Image.open('./webui/images/42.jpg').convert('RGB'), 'MultiDiffusion', 512, 1024]],
            [texture_image, synthesis_way, height, width]
        )
        ips = [texture_image, height, width, seed, num_steps, iterations, lr, mixed_precision, num_images_per_prompt, synthesis_way, model]

        run_button.click(fn=runner.run_texture_synthesis, inputs=ips, outputs=[result_gallery])