Klayand committed on
Commit
ef33138
·
1 Parent(s): aa21179
Files changed (36)
  1. diffusion_pipeline/__pycache__/gemma.cpython-310.pyc +0 -0
  2. diffusion_pipeline/__pycache__/lora.cpython-310.pyc +0 -0
  3. diffusion_pipeline/__pycache__/lora.cpython-312.pyc +0 -0
  4. diffusion_pipeline/__pycache__/refine_model.cpython-310.pyc +0 -0
  5. diffusion_pipeline/__pycache__/refine_model.cpython-312.pyc +0 -0
  6. diffusion_pipeline/__pycache__/refine_model.cpython-38.pyc +0 -0
  7. diffusion_pipeline/__pycache__/sd35_cfgpp.cpython-310.pyc +0 -0
  8. diffusion_pipeline/__pycache__/sd35_pipeline.cpython-310.pyc +0 -0
  9. diffusion_pipeline/__pycache__/sd35_pipeline.cpython-312.pyc +0 -0
  10. diffusion_pipeline/__pycache__/sd35_pipeline.cpython-38.pyc +0 -0
  11. diffusion_pipeline/__pycache__/sdxl_cfgpp.cpython-310.pyc +0 -0
  12. diffusion_pipeline/__pycache__/sdxl_pipeline.cpython-310.pyc +0 -0
  13. diffusion_pipeline/__pycache__/sdxl_pipeline.cpython-312.pyc +0 -0
  14. diffusion_pipeline/__pycache__/sdxl_pipeline.cpython-38.pyc +0 -0
  15. diffusion_pipeline/__pycache__/stable_diffusion_35_meta_learning.cpython-310.pyc +0 -0
  16. diffusion_pipeline/__pycache__/stable_diffusion_35_meta_learning2.cpython-310.pyc +0 -0
  17. diffusion_pipeline/__pycache__/stable_diffusion_35_pag.cpython-310.pyc +0 -0
  18. diffusion_pipeline/__pycache__/stable_diffusion_35_projection.cpython-310.pyc +0 -0
  19. diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_cfg.cpython-310.pyc +0 -0
  20. diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_cfg.cpython-312.pyc +0 -0
  21. diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_cfg.cpython-38.pyc +0 -0
  22. diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_zigzag.cpython-310.pyc +0 -0
  23. diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_zigzag_sec.cpython-310.pyc +0 -0
  24. diffusion_pipeline/__pycache__/stable_diffusion_35_zigzag.cpython-310.pyc +0 -0
  25. diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_cfg.cpython-310.pyc +0 -0
  26. diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_cfg.cpython-312.pyc +0 -0
  27. diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_cfg.cpython-38.pyc +0 -0
  28. diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_zigzag.cpython-310.pyc +0 -0
  29. diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_zigzag_test.cpython-310.pyc +0 -0
  30. diffusion_pipeline/__pycache__/stable_diffusion_xl_zigzag.cpython-310.pyc +0 -0
  31. diffusion_pipeline/gemma.py +53 -0
  32. diffusion_pipeline/lora.py +62 -0
  33. diffusion_pipeline/refine_model.py +526 -0
  34. diffusion_pipeline/sd35_pipeline.py +0 -0
  35. diffusion_pipeline/sdxl_pipeline.py +0 -0
  36. sample_img.py +237 -0
diffusion_pipeline/__pycache__/gemma.cpython-310.pyc ADDED
Binary file (2.33 kB).
diffusion_pipeline/__pycache__/lora.cpython-310.pyc ADDED
Binary file (2.71 kB).
diffusion_pipeline/__pycache__/lora.cpython-312.pyc ADDED
Binary file (4.43 kB).
diffusion_pipeline/__pycache__/refine_model.cpython-310.pyc ADDED
Binary file (16.6 kB).
diffusion_pipeline/__pycache__/refine_model.cpython-312.pyc ADDED
Binary file (37.9 kB).
diffusion_pipeline/__pycache__/refine_model.cpython-38.pyc ADDED
Binary file (16.7 kB).
diffusion_pipeline/__pycache__/sd35_cfgpp.cpython-310.pyc ADDED
Binary file (48.4 kB).
diffusion_pipeline/__pycache__/sd35_pipeline.cpython-310.pyc ADDED
Binary file (46.3 kB).
diffusion_pipeline/__pycache__/sd35_pipeline.cpython-312.pyc ADDED
Binary file (78.5 kB).
diffusion_pipeline/__pycache__/sd35_pipeline.cpython-38.pyc ADDED
Binary file (45.6 kB).
diffusion_pipeline/__pycache__/sdxl_cfgpp.cpython-310.pyc ADDED
Binary file (77.6 kB).
diffusion_pipeline/__pycache__/sdxl_pipeline.cpython-310.pyc ADDED
Binary file (72.3 kB).
diffusion_pipeline/__pycache__/sdxl_pipeline.cpython-312.pyc ADDED
Binary file (118 kB).
diffusion_pipeline/__pycache__/sdxl_pipeline.cpython-38.pyc ADDED
Binary file (72.1 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_meta_learning.cpython-310.pyc ADDED
Binary file (40.9 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_meta_learning2.cpython-310.pyc ADDED
Binary file (41.1 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_pag.cpython-310.pyc ADDED
Binary file (33.2 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_projection.cpython-310.pyc ADDED
Binary file (40.8 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_cfg.cpython-310.pyc ADDED
Binary file (55.2 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_cfg.cpython-312.pyc ADDED
Binary file (73.3 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_cfg.cpython-38.pyc ADDED
Binary file (46.6 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_zigzag.cpython-310.pyc ADDED
Binary file (36.6 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_smooth_zigzag_sec.cpython-310.pyc ADDED
Binary file (36.3 kB).
diffusion_pipeline/__pycache__/stable_diffusion_35_zigzag.cpython-310.pyc ADDED
Binary file (36.5 kB).
diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_cfg.cpython-310.pyc ADDED
Binary file (73.2 kB).
diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_cfg.cpython-312.pyc ADDED
Binary file (82.2 kB).
diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_cfg.cpython-38.pyc ADDED
Binary file (73 kB).
diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_zigzag.cpython-310.pyc ADDED
Binary file (54.2 kB).
diffusion_pipeline/__pycache__/stable_diffusion_xl_smooth_zigzag_test.cpython-310.pyc ADDED
Binary file (54.9 kB).
diffusion_pipeline/__pycache__/stable_diffusion_xl_zigzag.cpython-310.pyc ADDED
Binary file (75.6 kB).
 
diffusion_pipeline/gemma.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, Gemma2ForTokenClassification, BitsAndBytesConfig
+
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ torch.set_float32_matmul_precision("high")
+
+ def repeat_function(xs, max_length = 128):
+     new_xs = []
+     for x in xs:
+         if x.shape[1] >= max_length-1:
+             new_xs.append(x[:,:max_length-1,:])
+         else:
+             new_xs.append(x)
+     xs = new_xs
+     mean_xs = [x.mean(1,keepdim=True).expand(-1,max_length - x.shape[1],-1) for x in xs]
+     xs = [torch.cat([x,mean_x],1) for mean_x, x in zip(mean_xs, xs)]
+     return xs
+
+ class Gemma2Model(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", )
+         self.tokenizer_max_length = 128
+         # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+         self.model = Gemma2ForTokenClassification.from_pretrained(
+             "google/gemma-2-2b",
+             # device_map="auto",
+             # quantization_config=quantization_config,
+         ).float()
+         self.model.score = nn.Identity()
+
+     @torch.no_grad()
+     def forward(self, input_prompt):
+         input_prompt = list(input_prompt)
+         outputs = []
+         for _input_prompt in input_prompt:
+             input_ids = self.tokenizer(_input_prompt, add_special_tokens=False, max_length=77, return_tensors="pt").to("cuda")
+             _outputs = self.model(**input_ids)["logits"]
+             outputs.append(_outputs)
+         outputs = repeat_function(outputs)
+         outputs = torch.cat(outputs,0)
+         return outputs
+
+ if __name__ == "__main__":
+     model = Gemma2Model().cuda()
+     input_text = ["Write me a poem about Machine Learning.", "Write me a poem about Deep Learning."]
+     print(model(input_text))
+     print(model(input_text)[0].shape)
+     print(model(input_text).shape)
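
A minimal sketch (not part of the commit) of what repeat_function does to a list of variable-length token features, assuming toy tensors and a hypothetical feature width of 2304:

import torch
from diffusion_pipeline.gemma import repeat_function

# two hypothetical per-prompt feature tensors with different sequence lengths
feats = [torch.randn(1, 10, 2304), torch.randn(1, 140, 2304)]
padded = repeat_function(feats, max_length=128)
# short sequences are padded with their mean token; long ones are truncated to 127 tokens, then mean-padded
print([p.shape for p in padded])  # both (1, 128, 2304)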
diffusion_pipeline/lora.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import torch.nn as nn
+
+ class LoRALayer(torch.nn.Module):
+     def __init__(self, in_dim, out_dim, rank, alpha):
+         super().__init__()
+         std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
+         self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
+         self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
+         self.alpha = alpha
+
+     def forward(self, x):
+         x = self.alpha * (x @ self.A @ self.B)
+         return x
+
+ class LinearWithLoRA(torch.nn.Module):
+     def __init__(self, linear, rank, alpha,
+                  weak_lora_alpha=0.1, number_of_lora=1):
+         super().__init__()
+         self.linear = linear
+         self.lora = nn.ModuleList([LoRALayer(
+             linear.in_features, linear.out_features, rank, alpha
+         ) for _ in range(number_of_lora)])
+         self.use_lora = True
+         self.lora_idx = 0
+
+     def forward(self, x):
+         if self.use_lora:
+             return self.linear(x) + self.lora[self.lora_idx](x)
+         else:
+             return self.linear(x)
+
+ def replace_linear_with_lora(module, rank=64, alpha=1., tag=0, weak_lora_alpha=0.1, number_of_lora=1):
+     for name, child in module.named_children():
+         if isinstance(child, nn.Linear):
+             setattr(module, name, LinearWithLoRA(child, rank, alpha, weak_lora_alpha=weak_lora_alpha, number_of_lora=number_of_lora))
+         else:
+             replace_linear_with_lora(child, rank, alpha, tag, weak_lora_alpha=weak_lora_alpha, number_of_lora=number_of_lora)
+
+
+ def lora_false(model, lora_idx=0):
+     for name, module in model.named_modules():
+         if isinstance(module, LinearWithLoRA):
+             module.use_lora = False
+             module.lora_idx = lora_idx
+
+ def lora_true(model, lora_idx=0):
+     for name, module in model.named_modules():
+         if isinstance(module, LinearWithLoRA):
+             module.use_lora = True
+             module.lora_idx = lora_idx
+             for i, lora in enumerate(module.lora):
+                 if i != lora_idx:
+                     lora.A.requires_grad = False
+                     lora.B.requires_grad = False
+                     if lora.A.grad is not None:
+                         del lora.A.grad
+                     if lora.B.grad is not None:
+                         del lora.B.grad
+                 else:
+                     lora.A.requires_grad = True
+                     lora.B.requires_grad = True
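
A minimal usage sketch for the helpers above (not part of the commit), assuming a toy two-layer network in place of the refinement model:

import torch
import torch.nn as nn
from diffusion_pipeline.lora import replace_linear_with_lora, lora_true, lora_false

toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
replace_linear_with_lora(toy, rank=4, alpha=1.0, number_of_lora=2)  # every nn.Linear becomes LinearWithLoRA holding 2 adapters
lora_true(toy, lora_idx=1)   # route through adapter 1, freeze the parameters of adapter 0
out = toy(torch.randn(8, 16))
lora_false(toy)              # fall back to the plain linear path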
diffusion_pipeline/refine_model.py ADDED
@@ -0,0 +1,526 @@
+
+ import torch
+ import torch.nn as nn
+ import os
+ import json
+ import torch.nn.functional as F
+ import random
+ from torch.utils.data import Dataset
+ from transformers import AutoTokenizer
+ from glob import glob
+ import math
+ from PIL import Image
+ device = torch.device('cuda')
+ import numpy as np
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from diffusers.utils import logging
+ from diffusers.models.embeddings import PatchEmbed
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.attention import BasicTransformerBlock
+ from diffusers.models.normalization import AdaLayerNormContinuous
+ from torchvision import transforms
+
+ def add_hook_to_module(model, module_name):
+     outputs = []
+     def hook(module, input, output):
+         outputs.append(output)
+     module = dict(model.named_modules()).get(module_name)
+     if module is None:
+         raise ValueError(f"can't find module {module_name}")
+     hook_handle = module.register_forward_hook(hook)
+     return hook_handle, outputs
+
+ class PromptSD35Net(nn.Module):
+
+     def __init__(self,
+                  sample_size: int = 128,
+                  patch_size: int = 2,
+                  in_channels: int = 16,
+                  num_layers: int = 8,
+                  attention_head_dim: int = 64,
+                  num_attention_heads: int = 24,
+                  out_channels: int = 16,
+                  pos_embed_max_size: int = 192
+                  ):
+         super().__init__()
+         self.sample_size = sample_size
+         self.patch_size = patch_size
+         self.in_channels = in_channels
+         self.num_layers = num_layers
+         self.attention_head_dim = attention_head_dim
+         self.num_attention_heads = num_attention_heads
+         self.out_channels = out_channels
+         self.pos_embed_max_size = pos_embed_max_size
+         self.inner_dim = self.num_attention_heads * self.attention_head_dim
+
+         self.pos_embed = PatchEmbed(
+             height=self.sample_size,
+             width=self.sample_size,
+             patch_size=self.patch_size,
+             in_channels=self.in_channels,
+             embed_dim=self.inner_dim,
+             pos_embed_max_size=pos_embed_max_size
+         )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=self.num_attention_heads,
+                     attention_head_dim=self.attention_head_dim,
+                     ff_inner_dim=2*self.inner_dim # mult should be 4 by default
+                 )
+                 for i in range(self.num_layers)
+             ]
+         )
+         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+         self.noise_shape = (1, 16, 128, 128) # (667, 4096)
+         self.pre8_linear = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre16_linear = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre24_linear = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+
+         self.pre8_linear2 = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre16_linear2 = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre24_linear2 = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+
+         self.last_linear = nn.Sequential(nn.Linear(4096, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         # self.last_linear2 = nn.Sequential(nn.Linear(667, 32))
+         self.skip_connection2 = nn.Linear(4096, 1, bias=False)
+         self.skip_connection = nn.Linear(667, 32, bias=False)
+         self.trans_linear = nn.Linear(666+1+4096, 1536, bias=False)
+         nn.init.constant_(self.skip_connection.weight.data, 0)
+         nn.init.constant_(self.trans_linear.weight.data, 0)
+         nn.init.constant_(self.trans_linear.weight.data, 0)
+         nn.init.constant_(self.pre8_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre16_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre24_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre8_linear2[-1].weight.data, 0)
+         nn.init.constant_(self.pre16_linear2[-1].weight.data, 0)
+         nn.init.constant_(self.pre24_linear2[-1].weight.data, 0)
+
+     def forward(self, noise: torch.Tensor, _s, _v, _d, _pool_embedding) -> torch.Tensor:
+
+         assert noise is not None
+         _ori_v = _v.clone()
+         _v = torch.stack([torch.diag(_v[jj]) for jj in range(_v.shape[0])], dim=0)
+         positive_embedding = _s.permute(0, 2, 1) @ _v @ _d # [2, 64, 666] [2, 64] [2, 64, 4096]
+         pool_embedding = _pool_embedding[:, None, :]
+         embedding = torch.cat([positive_embedding, pool_embedding], dim=1)
+         bs = noise.shape[0]
+         height, width = noise.shape[-2:]
+         embed_8 = embedding
+         embed_16 = embedding
+         embed_24 = embedding
+         scale_8 = self.pre8_linear2(embed_8).mean(1)
+         scale_16 = self.pre16_linear2(embed_16).mean(1)
+         scale_24 = self.pre24_linear2(embed_24).mean(1)
+         embed_8 = self.pre8_linear(embed_8).mean(1)
+         embed_16 = self.pre16_linear(embed_16).mean(1)
+         embed_24 = self.pre24_linear(embed_24).mean(1)
+         embed_last = self.last_linear(embedding).mean(1)
+         embed_trans = self.trans_linear(torch.cat([_s, _ori_v[...,None], _d], dim=2)).mean(1)
+         skip_embedding = self.skip_connection(self.skip_connection2(embedding).permute(0,2,1)).permute(0,2,1)
+         scale_skip, embed_skip = skip_embedding.chunk(2,dim=1)
+
+         ori_noise = noise * (scale_skip[...,None]) + embed_skip[...,None]
+         noise = self.pos_embed(noise)
+         noise = noise * (1 + scale_8[:, None, :] + embed_trans[:, None, :]) + embed_8[:, None, :]
+         scale_list = [scale_16, scale_24]
+         embed_list = [embed_16, embed_24]
+         for _ii, block in enumerate(self.transformer_blocks):
+             noise = block(noise)
+             if len(scale_list)!=0 and len(embed_list)!=0:
+                 noise = noise * (1 + scale_list[int(_ii//4)][:, None, :] + embed_trans[:, None, :]) + embed_list[int(_ii//4)][:, None, :]
+
+         hidden_states = noise
+         hidden_states = self.norm_out(hidden_states, embed_last)
+         hidden_states = self.proj_out(hidden_states)
+
+         # unpatchify
+         patch_size = self.patch_size
+         height = height // patch_size
+         width = width // patch_size
+
+         hidden_states = hidden_states.reshape(
+             shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+         )
+         hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+         output = hidden_states.reshape(
+             shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+         )
+         return output + ori_noise
+
+     def weak_load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
+         return load_filtered_state_dict(self, state_dict)
+
+ class PromptSDXLNet(nn.Module):
+
+     def __init__(self,
+                  sample_size: int = 128,
+                  patch_size: int = 2,
+                  in_channels: int = 4,
+                  num_layers: int = 4,
+                  attention_head_dim: int = 64,
+                  num_attention_heads: int = 24,
+                  out_channels: int = 4,
+                  pos_embed_max_size: int = 192
+                  ):
+         super().__init__()
+         self.sample_size = sample_size
+         self.patch_size = patch_size
+         self.in_channels = in_channels
+         self.num_layers = num_layers
+         self.attention_head_dim = attention_head_dim
+         self.num_attention_heads = num_attention_heads
+         self.out_channels = out_channels
+         self.pos_embed_max_size = pos_embed_max_size
+         self.inner_dim = self.num_attention_heads * self.attention_head_dim
+
+         self.pos_embed = PatchEmbed(
+             height=self.sample_size,
+             width=self.sample_size,
+             patch_size=self.patch_size,
+             in_channels=self.in_channels,
+             embed_dim=self.inner_dim,
+             pos_embed_max_size=pos_embed_max_size
+         )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=self.num_attention_heads,
+                     attention_head_dim=self.attention_head_dim,
+                     ff_inner_dim=2*self.inner_dim # mult should be 4 by default
+                 )
+                 for i in range(self.num_layers)
+             ]
+         )
+         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+         self.noise_shape = (1, 4, 128, 128)
+         self.pre8_linear = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre16_linear = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre24_linear = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+
+         self.pre8_linear2 = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre16_linear2 = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         self.pre24_linear2 = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+
+         self.last_linear = nn.Sequential(nn.Linear(2048, 128), nn.SiLU(), nn.LayerNorm(128), nn.Linear(128, 1536))
+         # self.last_linear2 = nn.Sequential(nn.Linear(667, 32))
+         self.skip_connection2 = nn.Linear(2048, 1, bias=False)
+         self.skip_connection = nn.Linear(154+1, 8, bias=False)
+         self.trans_linear = nn.Linear(154+1+2048, 1536, bias=False)
+         self.pool_prompt_linear = nn.Linear(2560, 2048, bias=False)
+         nn.init.constant_(self.skip_connection.weight.data, 0)
+         nn.init.constant_(self.trans_linear.weight.data, 0)
+         nn.init.constant_(self.trans_linear.weight.data, 0)
+         nn.init.constant_(self.pre8_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre16_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre24_linear[-1].weight.data, 0)
+         nn.init.constant_(self.pre8_linear2[-1].weight.data, 0)
+         nn.init.constant_(self.pre16_linear2[-1].weight.data, 0)
+         nn.init.constant_(self.pre24_linear2[-1].weight.data, 0)
+
+     def forward(self, noise: torch.Tensor, _s, _v, _d, _pool_embedding) -> torch.Tensor:
+
+         assert noise is not None
+         _ori_v = _v.clone()
+         _v = torch.stack([torch.diag(_v[jj]) for jj in range(_v.shape[0])], dim=0)
+         positive_embedding = _s.permute(0, 2, 1) @ _v @ _d # [2, 64, 154] [2, 64] [2, 64, 2048]
+         pool_embedding = self.pool_prompt_linear(_pool_embedding[:, None, :])
+         embedding = torch.cat([positive_embedding, pool_embedding], dim=1)
+         bs = noise.shape[0]
+         height, width = noise.shape[-2:]
+         embed_8 = embedding
+         embed_16 = embedding
+         embed_24 = embedding
+         scale_8 = self.pre8_linear2(embed_8).mean(1)
+         scale_16 = self.pre16_linear2(embed_16).mean(1)
+         scale_24 = self.pre24_linear2(embed_24).mean(1)
+         embed_8 = self.pre8_linear(embed_8).mean(1)
+         embed_16 = self.pre16_linear(embed_16).mean(1)
+         embed_24 = self.pre24_linear(embed_24).mean(1)
+         embed_last = self.last_linear(embedding).mean(1)
+         embed_trans = self.trans_linear(torch.cat([_s, _ori_v[...,None], _d], dim=2)).mean(1)
+         skip_embedding = self.skip_connection(self.skip_connection2(embedding).permute(0,2,1)).permute(0,2,1)
+         scale_skip, embed_skip = skip_embedding.chunk(2,dim=1)
+
+         ori_noise = noise * (scale_skip[...,None]) + embed_skip[...,None]
+         noise = self.pos_embed(noise)
+         noise = noise * (1 + scale_8[:, None, :] + embed_trans[:, None, :]) + embed_8[:, None, :]
+         scale_list = [scale_16, scale_24]
+         embed_list = [embed_16, embed_24]
+         for _ii, block in enumerate(self.transformer_blocks):
+             noise = block(noise)
+             if len(scale_list)!=0 and len(embed_list)!=0:
+                 noise = noise * (1 + scale_list[int(_ii//4)][:, None, :] + embed_trans[:, None, :]) + embed_list[int(_ii//4)][:, None, :]
+
+         hidden_states = noise
+         hidden_states = self.norm_out(hidden_states, embed_last)
+         hidden_states = self.proj_out(hidden_states)
+
+         # unpatchify
+         patch_size = self.patch_size
+         height = height // patch_size
+         width = width // patch_size
+
+         hidden_states = hidden_states.reshape(
+             shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+         )
+         hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+         output = hidden_states.reshape(
+             shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+         )
+         return output + ori_noise
+
+     def weak_load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
+         return load_filtered_state_dict(self, state_dict)
+
+
+ class NoisePromptDataset(Dataset):
+     def __init__(self, if_weight=False):
+
+         self.if_weight = if_weight
+         json_list = glob('/home/xiedian/total_datacollect/json/*.json')
+         self.original_score = []
+         self.optim_score = []
+         self.prompt = []
+         self.noise_paths = []
+         self.mask_conditions = []
+         self.embeddings = []
+         counter = 0
+         for i in range(len(json_list)):
+             with open('//home/xiedian/total_datacollect/json/new{:06d}.json'.format(i), 'r') as f:
+                 data = json.load(f)
+             self.original_score.append(data['original_score_list'])
+             self.optim_score.append(data['optimized_score_list'])
+             if data['optimized_score_list']>data['original_score_list']:
+                 counter += 1
+             self.prompt.append(data['caption'])
+             self.noise_paths.append('/home/xiedian/total_datacollect/latents/{:06d}.pt'.format(i))
+             self.embeddings.append('/data/xiedian/hg_reward/newdatacollect/total_datacollect_step_1_1/embedding/embeds_{:06d}.pt'.format(i))
+             z = [0, 1] * ((512+77+77) // 2)
+             self.mask_conditions.append(data['mid_token_ids'] if 'mid_token_ids' in data else z)
+         # while counter * 2 > len(self.prompt):
+         #     p = random.randint(0,len(self.prompt)-1)
+         #     if self.original_score[p] > self.optim_score[p]:
+         #         self.optim_score.append(self.optim_score[p])
+         #         self.original_score.append(self.original_score[p])
+         #         self.mask_conditions.append(self.mask_conditions[p])
+         #         self.noise_paths.append(self.noise_paths[p])
+         #         self.prompt.append(self.prompt[p])
+
+         # while counter * 2 < len(self.prompt):
+         #     p = random.randint(0,len(self.prompt)-1)
+         #     if self.original_score[p] > self.optim_score[p]:
+         #         self.optim_score.append(self.optim_score[p])
+         #         self.original_score.append(self.original_score[p])
+         #         self.mask_conditions.append(self.mask_conditions[p])
+         #         self.noise_paths.append(self.noise_paths[p])
+         #         self.prompt.append(self.prompt[p])
+
+         self.original_score = torch.Tensor(self.original_score)
+         self.optim_score = torch.Tensor(self.optim_score)
+
+     def __len__(self):
+         return len(self.prompt)
+
+     def __getitem__(self, index):
+         try:
+             noise = torch.load(self.noise_paths[index], map_location='cpu').squeeze(0).float()
+             noise_pred_uncond, mid_noise_pred, noise_pred_text = noise.chunk(3,dim=0)
+             prompt = self.prompt[index]
+             original_score = self.original_score[index]
+             optim_score = self.optim_score[index]
+             embedding = torch.load(self.embeddings[index], map_location='cpu')
+             _s, _v, _d, _pool_embedding = embedding['_s'], embedding['_v'], embedding['_d'], embedding['pooled_prompt_embeds']
+             _s = _s.detach().float()
+             _v = _v.detach().float()
+             _d = _d.detach().float()
+             _pool_embedding = _pool_embedding.detach().float()
+             if original_score > optim_score:
+                 noise_pred = noise_pred_uncond + 4.5 * (noise_pred_text - noise_pred_uncond)
+             else:
+                 guidance_scale = 4.5 * 1.6
+                 diff_text = torch.norm(noise_pred_text - noise_pred_uncond)
+                 mid_guidance_scale = (diff_text / torch.norm(noise_pred_text - mid_noise_pred)).item()
+                 guidance_scale_mid = guidance_scale / (2.4 + 1)
+                 guidance_scale_all = guidance_scale * 2.4 / (2.4 + 1)
+                 all_mid = (noise_pred_text - mid_noise_pred) * mid_guidance_scale
+                 all_null = noise_pred_text - noise_pred_uncond
+                 noise_pred = all_mid * guidance_scale_mid + all_null * guidance_scale_all + (mid_noise_pred + noise_pred_uncond) / 2
+         except:
+             print("error", index)
+             return self.__getitem__((index+1)%len(self.prompt))
+         if self.if_weight:
+             return noise_pred_text, prompt, noise_pred, 2 / (1+ math.exp((-abs(original_score) + 26.5)/1.7)), _s, _v, _d, _pool_embedding
+         else:
+             return noise_pred_text, prompt, noise_pred, _s, _v, _d, _pool_embedding
+
+ class NoisePromptDataset_2_0(Dataset):
+     def __init__(self, if_weight=False):
+
+         self.if_weight = if_weight
+         json_list = glob('/data/xiedian/hg_reward/newdatacollect/total_datacollect_step_1_1/json/*.json')
+         self.original_score = []
+         self.quick_score = []
+         self.slow_score = []
+         self.prompt = []
+         self.noise_paths = []
+         self.mask_conditions = []
+         self.img_list = []
+         self.embeddings = []
+
+         counter = 0
+         for i in range(len(json_list)):
+             with open('/data/xiedian/hg_reward/newdatacollect/total_datacollect_step_1_1/json/new{:06d}.json'.format(i), 'r') as f:
+                 data = json.load(f)
+             if (not os.path.exists('/data/xiedian/hg_reward/newdatacollect/total_datacollect_step_1_1/latents/{:06d}.pt'.format(i))) or \
+                 max(data['original_score_list'], data['quick_score_list'], data['slow_score_list']) != data['original_score_list']:
+                 continue
+             self.original_score.append(data['original_score_list'])
+             self.quick_score.append(data['quick_score_list'])
+             self.slow_score.append(data['slow_score_list'])
+             self.prompt.append(data['caption'])
+             self.noise_paths.append('/data/xiedian/hg_reward/newdatacollect/total_datacollect_step_1_1/latents/{:06d}.pt'.format(i))
+             z = [0, 1] * ((512+77+77) // 2)
+             self.mask_conditions.append(data['mid_token_ids'] if 'mid_token_ids' in data else z)
+             if data['original_score_list'] >= max(data['quick_score_list'], data['slow_score_list']):
+                 self.img_list.append('/data/xiedian/hg_reward/newdatacollect/total_datacollect_step_1_1/optim/original{:06d}.png'.format(i))
+             elif data['quick_score_list'] >= max(data['original_score_list'], data['slow_score_list']):
+                 self.img_list.append('/data/xiedian/hg_reward/newdatacollect/total_datacollect_step_1_1/optim/quick{:06d}.png'.format(i))
+             else:
+                 self.img_list.append('/data/xiedian/hg_reward/newdatacollect/total_datacollect_step_1_1/optim/slow{:06d}.png'.format(i))
+             self.embeddings.append('/home/xiedian/total_datacollect/embedding/embeds_{:06d}.pt'.format(i))
+         self.original_score = torch.Tensor(self.original_score)
+         self.quick_score = torch.Tensor(self.quick_score)
+         self.slow_score = torch.Tensor(self.slow_score)
+
+     def __len__(self):
+         return len(self.prompt)
+
+     def __getitem__(self, index):
+         try:
+             original_score = self.original_score[index]
+             quick_score = self.quick_score[index]
+             slow_score = self.slow_score[index]
+             original_score = max(max(quick_score, slow_score), original_score)
+             embedding = torch.load(self.embeddings[index], map_location='cpu')
+             _s, _v, _d, _pool_embedding = embedding['_s'], embedding['_v'], embedding['_d'], embedding['pooled_prompt_embeds']
+             _s = _s.detach().float()
+             _v = _v.detach().float()
+             _d = _d.detach().float()
+             _pool_embedding = _pool_embedding.detach().float()
+             noise = torch.load(self.noise_paths[index], map_location='cpu').squeeze(0).float()
+             noise_pred_text, noise_pred = noise.chunk(2,dim=0)
+             prompt = self.prompt[index]
+         except:
+             print("error", index)
+             return self.__getitem__((index+1)%len(self.prompt))
+         if self.if_weight:
+             return noise_pred_text, prompt, noise_pred, 2 / (1+ math.exp((-abs(original_score) + 26.5)/1.7)), _s, _v, _d, _pool_embedding
+         else:
+             return noise_pred_text, prompt, noise_pred, _s, _v, _d, _pool_embedding
+
+ class NoisePromptDataset_3_0(Dataset):
+     def __init__(self, if_weight=False):
+
+         self.if_weight = if_weight
+         json_list = glob('/data/xiedian/hg_reward/CFG_TOTAL/total_datacollect/json/*.json')
+         self.score = []
+         self.prompt = []
+         self.noise_paths = []
+         self.mask_conditions = []
+         self.img_list = []
+         self.embeddings = []
+
+         print(len(json_list))
+
+         for i in range(len(json_list)):
+             if (not os.path.exists('/data/xiedian/hg_reward/CFG_TOTAL/total_datacollect/embedding/{:06d}.pt'.format(i))):
+                 continue
+
+             with open('/data/xiedian/hg_reward/CFG_TOTAL/total_datacollect/json/new{:06d}.json'.format(i), 'r') as f:
+                 data = json.load(f)
+             if data['original_score_list'] > data['optimized_score_list']:
+                 tag = 0
+                 if (not os.path.exists('/data/xiedian/hg_reward/CFG_TOTAL/total_datacollect/latents/original{:06d}.pt'.format(i))):
+                     continue
+             else:
+                 tag = 1
+                 if (not os.path.exists('/data/xiedian/hg_reward/CFG_TOTAL/total_datacollect/latents/new{:06d}.pt'.format(i))):
+                     continue
+
+             if tag == 1:
+                 self.score.append(data['optimized_score_list'])
+                 self.noise_paths.append('/data/xiedian/hg_reward/CFG_TOTAL/total_datacollect/latents/new{:06d}.pt'.format(i))
+             else:
+                 self.score.append(data['original_score_list'])
+                 self.noise_paths.append('/data/xiedian/hg_reward/CFG_TOTAL/total_datacollect/latents/original{:06d}.pt'.format(i))
+             self.prompt.append(data['caption'])
+             self.embeddings.append('/data/xiedian/hg_reward/CFG_TOTAL/total_datacollect/embedding/{:06d}.pt'.format(i))
+         self.score = torch.Tensor(self.score)
+
+     def __len__(self):
+         return len(self.prompt)
+
+     def __getitem__(self, index):
+         try:
+             embedding = torch.load(self.embeddings[index], map_location='cpu')
+             _s, _v, _d, _pool_embedding = embedding['_s'], embedding['_v'], embedding['_d'], embedding['_pooled_prompt_embeds']
+             _s = _s.detach().float()
+             _v = _v.detach().float()
+             _d = _d.detach().float()
+             _pool_embedding = _pool_embedding.detach().float()
+             noise = torch.load(self.noise_paths[index], map_location='cpu').float() # [2XT, 16, 128, 128]
+             prompt = self.prompt[index] # [ori, target, ori]
+             score = self.score[index]
+         except:
+             print("error", index)
+             return self.__getitem__((index+1)%len(self.prompt))
+         if self.if_weight:
+             return noise, prompt, 2 / (1+ math.exp((-abs(score) + 26.5)/1.7)), _s, _v, _d, _pool_embedding
+         else:
+             return noise, prompt, _s, _v, _d, _pool_embedding
+
+
+ def load_filtered_state_dict(model, state_dict):
+     model_state_dict = model.state_dict()
+     filtered_state_dict = {}
+     for k, v in state_dict.items():
+         if k in model_state_dict:
+             if model_state_dict[k].size() == v.size():
+                 filtered_state_dict[k] = v
+             else:
+                 print(f"Skipping {k}: shape mismatch ({model_state_dict[k].size()} vs {v.size()})")
+         else:
+             print(f"Skipping {k}: not found in model's state_dict.")
+     model.load_state_dict(filtered_state_dict, strict=False)
+     return model
+
+ def custom_collate_fn_2_0(batch):
+     noise_pred_texts, prompts, noise_preds, max_scores = zip(*batch)
+
+     noise_pred_texts = torch.stack(noise_pred_texts)
+     noise_preds = torch.stack(noise_preds)
+     max_scores = torch.stack(max_scores)
+
+     return noise_pred_texts, prompts, noise_preds, max_scores
+
+
+ if __name__ == "__main__":
+     dataset = NoisePromptDataset(if_weight=True)
+     weights = []
+     for i, (noise, prompt, gt, weight, *_) in enumerate(dataset):  # the dataset also returns _s, _v, _d and the pooled embedding
+         weights.append(weight)
+     weights = torch.from_numpy(np.array(weights)).cuda()
+     print(weights.mean(), weights.std(dim=0))
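
A minimal sketch (not part of the commit) of partially loading a checkpoint with load_filtered_state_dict, assuming a hypothetical checkpoint path; keys missing from the model or with mismatched shapes are skipped with a printed message:

import torch
from diffusion_pipeline.refine_model import PromptSD35Net, load_filtered_state_dict

model = PromptSD35Net()
state = torch.load("path/to/refine_checkpoint.pth", map_location="cpu")  # hypothetical path
load_filtered_state_dict(model, state)  # copies only keys that exist in the model with matching shapes (strict=False)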
diffusion_pipeline/sd35_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusion_pipeline/sdxl_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
sample_img.py ADDED
@@ -0,0 +1,237 @@
+ import json
+ import numpy as np
+ import math
+ import csv
+ import random
+ import argparse
+ import torch
+ import os
+ import torch.distributed as dist
+
+ from PIL import Image
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ from accelerate.utils import set_seed
+
+ from diffusion_pipeline.sd35_pipeline import StableDiffusion3Pipeline, FlowMatchEulerInverseScheduler
+ from diffusion_pipeline.sdxl_pipeline import StableDiffusionXLPipeline
+ from diffusers import BitsAndBytesConfig, SD3Transformer2DModel
+ from diffusers import FlowMatchEulerDiscreteScheduler, DDIMInverseScheduler, DDIMScheduler
+
+ device = torch.device('cuda')
+
+ def get_args():
+     # pick: test_unique_caption_zh.csv draw: drawbench.csv
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model", default='sd35', choices=['sdxl', 'sd35'], type=str)
+     parser.add_argument("--inference-step", default=30, type=int)
+     parser.add_argument("--size", default=1024, type=int)
+     parser.add_argument("--seed", default=33, type=int)
+     parser.add_argument("--cfg", default=3.5, type=float)
+
+     # hyperparameters for Z-Sampling
+     parser.add_argument("--inv-cfg", default=0.5, type=float)
+
+     # hyperparameters for Z-Core^2
+     parser.add_argument("--w2s-guidance", default=1.5, type=float)
+     parser.add_argument("--end_timesteps", default=28, type=int) # equal to inference step - 2 or inference step
+
+
+     parser.add_argument("--prompt", default='Mickey Mouse painting by Frank Frazetta.', type=str)
+
+     parser.add_argument("--method", default='standard', choices=['standard', 'core', 'zigzag', 'z-core'], type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == '__main__':
+     torch.cuda.empty_cache()
+     dtype = torch.float16
+     args = get_args()
+     print("args.seed: ", args.seed)
+     set_seed(args.seed)
+
+     # TODO: load pipeline
+     if args.model == 'sd35':
+         nf4_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16
+         )
+         model_nf4 = SD3Transformer2DModel.from_pretrained(
+             "stabilityai/stable-diffusion-3.5-large",
+             subfolder="transformer",
+             quantization_config=nf4_config,
+             torch_dtype=torch.bfloat16
+         )
+
+         pipe = StableDiffusion3Pipeline.from_pretrained(
+             "stabilityai/stable-diffusion-3.5-large",
+             transformer=model_nf4,
+             torch_dtype=torch.bfloat16,
+         )
+
+         pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
+         inverse_scheduler = FlowMatchEulerInverseScheduler.from_pretrained("stabilityai/stable-diffusion-3.5-large",
+                                                                            subfolder='scheduler')
+         pipe.inv_scheduler = inverse_scheduler
+
+     elif args.model == "sdxl":
+         pipe = StableDiffusionXLPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0",
+             torch_dtype=torch.float16,
+             variant="fp16",
+             use_safetensors=True
+         ).to("cuda")
+
+         pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+         inverse_scheduler = DDIMInverseScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
+                                                                  subfolder='scheduler')
+         pipe.inv_scheduler = inverse_scheduler
+
+     pipe.to(device)
+     pipe.enable_model_cpu_offload()
+
+     # TODO: load noise model
+     if args.method == 'core' or args.method == 'z-core':
+         from diffusion_pipeline.refine_model import PromptSD35Net, PromptSDXLNet
+         from diffusion_pipeline.lora import replace_linear_with_lora, lora_true
+
+         if args.model == 'sd35':
+             refine_model = PromptSD35Net()
+             replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
+             lora_true(refine_model, lora_idx=0)
+
+             checkpoint = torch.load('./weights/sd35_ckpt_v9.pth', map_location='cpu')
+             refine_model.load_state_dict(checkpoint)
+         elif args.model == 'sdxl':
+             refine_model = PromptSDXLNet()
+             replace_linear_with_lora(refine_model, rank=48, alpha=1.0, number_of_lora=50)
+             lora_true(refine_model, lora_idx=0)
+
+             checkpoint = torch.load('./weights/sdxl_ckpt_v9.pth', map_location='cpu')
+             refine_model.load_state_dict(checkpoint)
+
+         print("Load Lora Success")
+         refine_model = refine_model.to(device)
+         refine_model = refine_model.to(torch.bfloat16)
+
+
+     # TODO: load hyperparameters
+     size = args.size
+     if args.model == 'sdxl':
+         shape = (1, 4, size // 8, size // 8)
+     else:
+         shape = (1, 16, size // 8, size // 8)
+
+     num_steps = args.inference_step
+     end_timesteps = args.end_timesteps
+     guidance_scale = args.cfg
+     w2s_guidance = args.w2s_guidance
+     inv_cfg = args.inv_cfg
+     prompt = args.prompt
+
+     print("pass this prompt: ", prompt)
+
+     start_latents = torch.randn(shape, dtype=dtype).to(device)
+
+     if args.model == 'sdxl':
+         if args.method == 'core':
+             output = pipe.core(
+                 prompt=prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_steps,
+                 latents=start_latents,
+                 return_dict=False,
+                 refine_model=refine_model,
+                 lora_true=lora_true,
+                 end_timesteps=end_timesteps,
+                 w2s_guidance=w2s_guidance)[0][0]
+
+         elif args.method == 'zigzag':
+             output = pipe.zigzag(
+                 prompt=prompt,
+                 guidance_scale=guidance_scale,
+                 latents=start_latents,
+                 return_dict=False,
+                 num_inference_steps=num_steps,
+                 inv_cfg=inv_cfg)[0][0]
+
+         elif args.method == 'z-core':
+             output = pipe.z_core(
+                 prompt=prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_steps,
+                 latents=start_latents,
+                 return_dict=False,
+                 refine_model=refine_model,
+                 lora_true=lora_true,
+                 end_timesteps=end_timesteps,
+                 w2s_guidance=w2s_guidance,
+                 inv_cfg=inv_cfg)[0][0]
+
+         elif args.method == 'standard':
+             output = pipe(
+                 prompt=prompt,
+                 guidance_scale=guidance_scale,
+                 latents=start_latents,
+                 return_dict=False,
+                 num_inference_steps=num_steps)[0][0]
+         else:
+             raise ValueError("Invalid method")
+
+         output.save(f'{args.model}_{args.method}.png')
+
+
+     else:
+         if args.method == 'core':
+             output = pipe.core(
+                 prompt=prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_steps,
+                 latents=start_latents,
+                 max_sequence_length=512,
+                 return_dict=False,
+                 refine_model=refine_model,
+                 lora_true=lora_true,
+                 end_timesteps=end_timesteps,
+                 w2s_guidance=w2s_guidance)[0][0]
+
+         elif args.method == 'zigzag':
+             output = pipe.zigzag(
+                 prompt=prompt,
+                 max_sequence_length=512,
+                 guidance_scale=guidance_scale,
+                 latents=start_latents,
+                 return_dict=False,
+                 num_inference_steps=num_steps,
+                 inv_cfg=inv_cfg)[0][0]
+
+         elif args.method == 'z-core':
+             output = pipe.z_core(
+                 prompt=prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_steps,
+                 latents=start_latents,
+                 return_dict=False,
+                 max_sequence_length=512,
+                 refine_model=refine_model,
+                 lora_true=lora_true,
+                 end_timesteps=end_timesteps,
+                 w2s_guidance=w2s_guidance)[0][0]
+
+         elif args.method == 'standard':
+             output = pipe(
+                 prompt=prompt,
+                 guidance_scale=guidance_scale,
+                 latents=start_latents,
+                 return_dict=False,
+                 max_sequence_length=512,
+                 num_inference_steps=num_steps)[0][0]
+         else:
+             raise ValueError("Invalid method")
+
+         output.save(f'{args.model}_{args.method}.png')
+
+
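
Based on the arguments defined in get_args, a typical invocation of the Z-Core^2 path might look like the following (the checkpoints under ./weights/ are assumed to be present); the result is written to sd35_z-core.png:

python sample_img.py --model sd35 --method z-core --prompt "Mickey Mouse painting by Frank Frazetta." --inference-step 30 --cfg 3.5 --inv-cfg 0.5 --w2s-guidance 1.5 --end_timesteps 28 --seed 33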