feat(project): init SVGDreamer
Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
- ImageReward/ImageReward.py +177 -0
- ImageReward/ReFL.py +830 -0
- ImageReward/__init__.py +3 -0
- ImageReward/models/AestheticScore.py +95 -0
- ImageReward/models/BLIP/__init__.py +1 -0
- ImageReward/models/BLIP/blip.py +70 -0
- ImageReward/models/BLIP/blip_pretrain.py +43 -0
- ImageReward/models/BLIP/med.py +947 -0
- ImageReward/models/BLIP/vit.py +301 -0
- ImageReward/models/BLIPScore.py +97 -0
- ImageReward/models/CLIPScore.py +78 -0
- ImageReward/models/__init__.py +4 -0
- ImageReward/utils.py +184 -0
- README.md +86 -12
- assets/Icon-SydneyOperaHouse/init_p0.svg +0 -0
- assets/Icon-SydneyOperaHouse/init_p1.svg +0 -0
- assets/Icon-SydneyOperaHouse/init_p2.svg +0 -0
- assets/Icon-SydneyOperaHouse/init_p3.svg +0 -0
- assets/Icon-SydneyOperaHouse/init_p4.svg +0 -0
- assets/Icon-SydneyOperaHouse/init_p5.svg +0 -0
- assets/Icon-SydneyOperaHouse/p_0.svg +0 -0
- assets/Icon-SydneyOperaHouse/p_1.svg +0 -0
- assets/Icon-SydneyOperaHouse/p_2.svg +0 -0
- assets/Icon-SydneyOperaHouse/p_3.svg +0 -0
- assets/Icon-SydneyOperaHouse/p_4.svg +0 -0
- assets/Icon-SydneyOperaHouse/p_5.svg +0 -0
- assets/{teaser1.png → illustrate.png} +2 -2
- assets/{teaser2.png → teaser_cases.png} +2 -2
- assets/{teaser3.png → teaser_more_cases.png} +2 -2
- assets/teaser_svg_asset.png +3 -0
- conf/config.yaml +54 -0
- conf/x/iconography.yaml +188 -0
- conf/x/ink.yaml +188 -0
- conf/x/lowpoly.yaml +188 -0
- conf/x/painting.yaml +188 -0
- conf/x/pixelart.yaml +188 -0
- conf/x/sketch.yaml +188 -0
- svgdreamer.py +42 -0
- svgdreamer/__init__.py +6 -0
- svgdreamer/diffusers_warp/__init__.py +248 -0
- svgdreamer/diffvg_warp/__init__.py +11 -0
- svgdreamer/diffvg_warp/diffvg_state.py +299 -0
- svgdreamer/libs/__init__.py +8 -0
- svgdreamer/libs/logging.py +65 -0
- svgdreamer/libs/model_state.py +253 -0
- svgdreamer/libs/optim.py +58 -0
- svgdreamer/painter/VPSD_pipeline.py +585 -0
- svgdreamer/painter/__init__.py +10 -0
- svgdreamer/painter/component_painter_params.py +610 -0
- svgdreamer/painter/diffusion_pipeline.py +402 -0
ImageReward/ImageReward.py
ADDED
@@ -0,0 +1,177 @@
'''
@File       :   ImageReward.py
@Time       :   2023/01/28 19:53:00
@Auther     :   Jiazheng Xu
@Contact    :   [email protected]
@Description:   ImageReward Reward model.
* Based on CLIP code base and improved-aesthetic-predictor code base
* https://github.com/openai/CLIP
* https://github.com/christophschuhmann/improved-aesthetic-predictor
'''

import os
import torch
import torch.nn as nn
from PIL import Image
from .models.BLIP.blip_pretrain import BLIP_Pretrain
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


class MLP(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size

        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            # nn.ReLU(),
            nn.Linear(16, 1)
        )

        # initial MLP param
        for name, param in self.layers.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
            if 'bias' in name:
                nn.init.constant_(param, val=0)

    def forward(self, input):
        return self.layers(input)


class ImageReward(nn.Module):
    def __init__(self, med_config, device='cpu'):
        super().__init__()
        self.device = device

        self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
        self.preprocess = _transform(224)
        self.mlp = MLP(768)

        self.mean = 0.16717362830052426
        self.std = 1.0333394966054072

    def score_gard(self, prompt_ids, prompt_attention_mask, image):

        image_embeds = self.blip.visual_encoder(image)
        # text encode cross attention with image
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
        text_output = self.blip.text_encoder(prompt_ids,
                                             attention_mask=prompt_attention_mask,
                                             encoder_hidden_states=image_embeds,
                                             encoder_attention_mask=image_atts,
                                             return_dict=True,
                                             )

        txt_features = text_output.last_hidden_state[:, 0, :]  # (feature_dim)
        rewards = self.mlp(txt_features)
        rewards = (rewards - self.mean) / self.std

        return rewards

    def score(self, prompt, image):

        if (type(image).__name__ == 'list'):
            _, rewards = self.inference_rank(prompt, image)
            return rewards

        # text encode
        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
                                         return_tensors="pt").to(self.device)

        # image encode
        if isinstance(image, Image.Image):
            pil_image = image
        elif isinstance(image, str):
            if os.path.isfile(image):
                pil_image = Image.open(image)
        else:
            raise TypeError(
                r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')

        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
        image_embeds = self.blip.visual_encoder(image)

        # text encode cross attention with image
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
        text_output = self.blip.text_encoder(text_input.input_ids,
                                             attention_mask=text_input.attention_mask,
                                             encoder_hidden_states=image_embeds,
                                             encoder_attention_mask=image_atts,
                                             return_dict=True,
                                             )

        txt_features = text_output.last_hidden_state[:, 0, :].float()  # (feature_dim)
        rewards = self.mlp(txt_features)
        rewards = (rewards - self.mean) / self.std

        return rewards.detach().cpu().numpy().item()

    def inference_rank(self, prompt, generations_list):

        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
                                         return_tensors="pt").to(self.device)

        txt_set = []
        for generation in generations_list:
            # image encode
            if isinstance(generation, Image.Image):
                pil_image = generation
            elif isinstance(generation, str):
                if os.path.isfile(generation):
                    pil_image = Image.open(generation)
            else:
                raise TypeError(
                    r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')

            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
            image_embeds = self.blip.visual_encoder(image)

            # text encode cross attention with image
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
            text_output = self.blip.text_encoder(text_input.input_ids,
                                                 attention_mask=text_input.attention_mask,
                                                 encoder_hidden_states=image_embeds,
                                                 encoder_attention_mask=image_atts,
                                                 return_dict=True)
            txt_set.append(text_output.last_hidden_state[:, 0, :])

        txt_features = torch.cat(txt_set, 0).float()  # [image_num, feature_dim]
        rewards = self.mlp(txt_features)  # [image_num, 1]
        rewards = (rewards - self.mean) / self.std
        rewards = torch.squeeze(rewards)
        _, rank = torch.sort(rewards, dim=0, descending=True)
        _, indices = torch.sort(rank, dim=0)
        indices = indices + 1

        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
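
Usage sketch (not part of this commit): the reward model defined above is normally obtained through `RM.load("ImageReward-v1.0")` from `ImageReward/utils.py` (added in this commit but not expanded in this view, and also used this way in ReFL.py below), then queried with `score` or `inference_rank`. The prompt and image paths are placeholders.

import ImageReward as RM

# Load the pretrained reward model (weights are fetched on first use).
model = RM.load("ImageReward-v1.0")

# Single image: returns a scalar reward for the prompt-image pair.
reward = model.score("a vector icon of the Sydney Opera House", "render_0.png")

# List of images: inference_rank returns 1-based ranks and per-image rewards.
ranks, rewards = model.inference_rank(
    "a vector icon of the Sydney Opera House",
    ["render_0.png", "render_1.png", "render_2.png"],
)
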
ImageReward/ReFL.py
ADDED
@@ -0,0 +1,830 @@
'''
@File       :   ReFL.py
@Time       :   2023/05/01 19:36:00
@Auther     :   Jiazheng Xu
@Contact    :   [email protected]
@Description:   ReFL Algorithm.
* Based on diffusers code base
* https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
'''

import argparse
import logging
import math
import os
import random
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from PIL import Image
import ImageReward as RM

from torchvision.transforms import Compose, Resize, CenterCrop, Normalize

try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate
from diffusers.utils.import_utils import is_xformers_available

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.16.0.dev0")

logger = get_logger(__name__, log_level="INFO")

DATASET_NAME_MAPPING = {
    "refl": ("image", "text"),
}


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--grad_scale", type=float, default=1e-3, help="Scale divided for grad loss value."
    )
    parser.add_argument(
        "--input_pertubation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--validation_prompts",
        type=str,
        default=None,
        nargs="+",
        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="checkpoint/refl",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=100,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=4,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
        "More details here: https://arxiv.org/abs/2303.09556.",
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--non_ema_revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
            " remote repository specified with --pretrained_model_name_or_path."
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=100,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=5,
        help="Run validation every X epochs.",
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="text2image-refl",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # default to using the same revision for the non-ema model if not specified
    if args.non_ema_revision is None:
        args.non_ema_revision = args.revision

    return args


class Trainer(object):

    def __init__(self, pretrained_model_name_or_path, train_data_dir, args):

        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.train_data_dir = train_data_dir

        # Sanity checks
        if args.dataset_name is None and self.train_data_dir is None:
            raise ValueError("Need either a dataset name or a training folder.")

        if args.non_ema_revision is not None:
            deprecate(
                "non_ema_revision!=None",
                "0.15.0",
                message=(
                    "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
                    " use `--variant=non_ema` instead."
                ),
            )
        logging_dir = os.path.join(args.output_dir, args.logging_dir)

        accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)

        self.accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision=args.mixed_precision,
            log_with=args.report_to,
            logging_dir=logging_dir,
            project_config=accelerator_project_config,
        )

        # Make one log on every process with the configuration for debugging.
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO,
        )
        logger.info(self.accelerator.state, main_process_only=False)
        if self.accelerator.is_local_main_process:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        # If passed along, set the training seed now.
        if args.seed is not None:
            set_seed(args.seed)

        # Handle the repository creation
        if self.accelerator.is_main_process:
            if args.output_dir is not None:
                os.makedirs(args.output_dir, exist_ok=True)

            if args.push_to_hub:
                self.repo_id = create_repo(
                    repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
                ).repo_id

        # Load scheduler, tokenizer and models.
        self.noise_scheduler = DDPMScheduler.from_pretrained(self.pretrained_model_name_or_path, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
        )
        self.vae = AutoencoderKL.from_pretrained(self.pretrained_model_name_or_path, subfolder="vae",
                                                 revision=args.revision)
        self.unet = UNet2DConditionModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
        )
        self.reward_model = RM.load("ImageReward-v1.0", device=self.accelerator.device)

        # Freeze vae and text_encoder
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.reward_model.requires_grad_(False)

        # Create EMA for the unet.
        if args.use_ema:
            self.ema_unet = UNet2DConditionModel.from_pretrained(
                self.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
            )
            self.ema_unet = EMAModel(self.ema_unet.parameters(), model_cls=UNet2DConditionModel,
                                     model_config=self.ema_unet.config)

        if args.enable_xformers_memory_efficient_attention:
            if is_xformers_available():
                import xformers

                xformers_version = version.parse(xformers.__version__)
                if xformers_version == version.parse("0.0.16"):
                    logger.warn(
                        "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                    )
                self.unet.enable_xformers_memory_efficient_attention()
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        # `accelerate` 0.16.0 will have better support for customized saving
        if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
            # create custom saving & loading hooks so that `self.accelerator.save_state(...)` serializes in a nice format
            def save_model_hook(models, weights, output_dir):
                if args.use_ema:
                    self.ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

                for i, model in enumerate(models):
                    model.save_pretrained(os.path.join(output_dir, "unet"))

                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

            def load_model_hook(models, input_dir):
                if args.use_ema:
                    load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
                    self.ema_unet.load_state_dict(load_model.state_dict())
                    self.ema_unet.to(self.accelerator.device)
                    del load_model

                for i in range(len(models)):
                    # pop models so that they are not loaded again
                    model = models.pop()

                    # load diffusers style into model
                    load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                    model.register_to_config(**load_model.config)

                    model.load_state_dict(load_model.state_dict())
                    del load_model

            self.accelerator.register_save_state_pre_hook(save_model_hook)
            self.accelerator.register_load_state_pre_hook(load_model_hook)

        if args.gradient_checkpointing:
            self.unet.enable_gradient_checkpointing()

        # Enable TF32 for faster training on Ampere GPUs,
        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if args.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        if args.scale_lr:
            args.learning_rate = (
                args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * self.accelerator.num_processes
            )

        # Initialize the optimizer
        if args.use_8bit_adam:
            try:
                import bitsandbytes as bnb
            except ImportError:
                raise ImportError(
                    "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
                )

            optimizer_cls = bnb.optim.AdamW8bit
        else:
            optimizer_cls = torch.optim.AdamW

        self.optimizer = optimizer_cls(
            self.unet.parameters(),
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )

        # Get the datasets: you can either provide your own training and evaluation files (see below)
        # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

        # In distributed training, the load_dataset function guarantees that only one local process can concurrently
        # download the dataset.
        if args.dataset_name is not None:
            # Downloading and loading a dataset from the hub.
            dataset = load_dataset(
                args.dataset_name,
                args.dataset_config_name,
                cache_dir=args.cache_dir,
            )
        else:
            data_files = {}
            data_files["train"] = self.train_data_dir
            dataset = load_dataset(
                "json",
                data_files=data_files,
                cache_dir=args.cache_dir,
            )
            # See more about loading custom images at
            # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

        # Preprocessing the datasets.
        # We need to tokenize inputs and targets.
        column_names = dataset["train"].column_names

        # Get the column names for input/target.
        dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
        if args.image_column is None:
            image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
        else:
            image_column = args.image_column
            if image_column not in column_names:
                raise ValueError(
                    f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
                )
        if args.caption_column is None:
            caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
        else:
            caption_column = args.caption_column
            if caption_column not in column_names:
                raise ValueError(
                    f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
                )

        # Preprocessing the datasets.
        # We need to tokenize input captions and transform the images.
        def tokenize_captions(examples, is_train=True):
            captions = []
            for caption in examples[caption_column]:
                if isinstance(caption, str):
                    captions.append(caption)
                elif isinstance(caption, (list, np.ndarray)):
                    # take a random caption if there are multiple
                    captions.append(random.choice(caption) if is_train else caption[0])
                else:
                    raise ValueError(
                        f"Caption column `{caption_column}` should contain either strings or lists of strings."
                    )
            inputs = tokenizer(
                captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
                return_tensors="pt"
            )
            return inputs.input_ids

        def preprocess_train(examples):
            examples["input_ids"] = tokenize_captions(examples)
            examples["rm_input_ids"] = self.reward_model.blip.tokenizer(examples[caption_column], padding='max_length',
                                                                        truncation=True, max_length=35,
                                                                        return_tensors="pt").input_ids
            examples["rm_attention_mask"] = self.reward_model.blip.tokenizer(examples[caption_column],
                                                                             padding='max_length', truncation=True,
                                                                             max_length=35,
                                                                             return_tensors="pt").attention_mask
            return examples

        with self.accelerator.main_process_first():
            if args.max_train_samples is not None:
                dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
            # Set the training transforms
            self.train_dataset = dataset["train"].with_transform(preprocess_train)

        def collate_fn(examples):
            input_ids = torch.stack([example["input_ids"] for example in examples])
            rm_input_ids = torch.stack([example["rm_input_ids"] for example in examples])
            rm_attention_mask = torch.stack([example["rm_attention_mask"] for example in examples])
            input_ids = input_ids.view(-1, input_ids.shape[-1])
            rm_input_ids = rm_input_ids.view(-1, rm_input_ids.shape[-1])
            rm_attention_mask = rm_attention_mask.view(-1, rm_attention_mask.shape[-1])
            return {"input_ids": input_ids, "rm_input_ids": rm_input_ids, "rm_attention_mask": rm_attention_mask}

        # DataLoaders creation:
        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            collate_fn=collate_fn,
            batch_size=args.train_batch_size,
            num_workers=args.dataloader_num_workers,
        )

        # Scheduler and math around the number of training steps.
        overrode_max_train_steps = False
        self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
        if args.max_train_steps is None:
            args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
            overrode_max_train_steps = True

        self.lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        )

        # Prepare everything with our `self.accelerator`.
        self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
            self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler
        )

        if args.use_ema:
            self.ema_unet.to(self.accelerator.device)

        # For mixed precision training we cast the text_encoder and vae weights to half-precision
        # as these models are only used for inference, keeping weights in full precision is not required.
        self.weight_dtype = torch.float32
        if self.accelerator.mixed_precision == "fp16":
            self.weight_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            self.weight_dtype = torch.bfloat16

        # Move text_encode and vae to gpu and cast to self.weight_dtype
        self.text_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
        self.vae.to(self.accelerator.device, dtype=self.weight_dtype)
        self.reward_model.to(self.accelerator.device, dtype=self.weight_dtype)

        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
        self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
        if overrode_max_train_steps:
            args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
        # Afterwards we recalculate our number of training epochs
        args.num_train_epochs = math.ceil(args.max_train_steps / self.num_update_steps_per_epoch)

        # We need to initialize the trackers we use, and also store our configuration.
        # The trackers initializes automatically on the main process.
        if self.accelerator.is_main_process:
            tracker_config = dict(vars(args))
            tracker_config.pop("validation_prompts")
            self.accelerator.init_trackers(args.tracker_project_name, tracker_config)

    def train(self, args):

        # Train!
        total_batch_size = args.train_batch_size * self.accelerator.num_processes * args.gradient_accumulation_steps

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(self.train_dataset)}")
        logger.info(f"  Num Epochs = {args.num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {args.max_train_steps}")
        global_step = 0
        first_epoch = 0

        # Potentially load in the weights and states from a previous save
        if args.resume_from_checkpoint:
            if args.resume_from_checkpoint != "latest":
                path = os.path.basename(args.resume_from_checkpoint)
            else:
                # Get the most recent checkpoint
                dirs = os.listdir(args.output_dir)
                dirs = [d for d in dirs if d.startswith("checkpoint")]
                dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
                path = dirs[-1] if len(dirs) > 0 else None

            if path is None:
                self.accelerator.print(
                    f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
                )
                args.resume_from_checkpoint = None
            else:
                self.accelerator.print(f"Resuming from checkpoint {path}")
                self.accelerator.load_state(os.path.join(args.output_dir, path))
                global_step = int(path.split("-")[1])

                resume_global_step = global_step * args.gradient_accumulation_steps
                first_epoch = global_step // self.num_update_steps_per_epoch
                resume_step = resume_global_step % (self.num_update_steps_per_epoch * args.gradient_accumulation_steps)

        # Only show the progress bar once on each machine.
        progress_bar = tqdm(range(global_step, args.max_train_steps),
                            disable=not self.accelerator.is_local_main_process)
        progress_bar.set_description("Steps")

        for epoch in range(first_epoch, args.num_train_epochs):
            self.unet.train()
            train_loss = 0.0
            for step, batch in enumerate(self.train_dataloader):
                # Skip steps until we reach the resumed step
                if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                    if step % args.gradient_accumulation_steps == 0:
                        progress_bar.update(1)
                    continue

                with self.accelerator.accumulate(self.unet):
                    encoder_hidden_states = self.text_encoder(batch["input_ids"])[0]
                    latents = torch.randn((args.train_batch_size, 4, 64, 64), device=self.accelerator.device)

                    self.noise_scheduler.set_timesteps(40, device=self.accelerator.device)
                    timesteps = self.noise_scheduler.timesteps

                    mid_timestep = random.randint(30, 39)

                    for i, t in enumerate(timesteps[:mid_timestep]):
                        with torch.no_grad():
                            latent_model_input = latents
                            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
                            noise_pred = self.unet(
                                latent_model_input,
                                t,
                                encoder_hidden_states=encoder_hidden_states,
                            ).sample
                            latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample

                    latent_model_input = latents
                    latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input,
                                                                                timesteps[mid_timestep])
                    noise_pred = self.unet(
                        latent_model_input,
                        timesteps[mid_timestep],
                        encoder_hidden_states=encoder_hidden_states,
                    ).sample
                    pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[mid_timestep],
                                                                     latents).pred_original_sample.to(self.weight_dtype)

                    pred_original_sample = 1 / self.vae.config.scaling_factor * pred_original_sample
                    image = self.vae.decode(pred_original_sample.to(self.weight_dtype)).sample
                    image = (image / 2 + 0.5).clamp(0, 1)

                    # image encode
                    def _transform():
                        return Compose([
                            Resize(224, interpolation=BICUBIC),
                            CenterCrop(224),
                            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                        ])

                    rm_preprocess = _transform()
                    image = rm_preprocess(image).to(self.accelerator.device)

                    rewards = self.reward_model.score_gard(batch["rm_input_ids"], batch["rm_attention_mask"], image)
                    loss = F.relu(-rewards + 2)
                    loss = loss.mean() * args.grad_scale

                    # Gather the losses across all processes for logging (if we use distributed training).
                    avg_loss = self.accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                    train_loss += avg_loss.item() / args.gradient_accumulation_steps

                    # Backpropagate
                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.unet.parameters(), args.max_grad_norm)
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # Checks if the self.accelerator has performed an optimization step behind the scenes
                if self.accelerator.sync_gradients:
                    if args.use_ema:
                        self.ema_unet.step(self.unet.parameters())
                    progress_bar.update(1)
                    global_step += 1
                    self.accelerator.log({"train_loss": train_loss}, step=global_step)
                    train_loss = 0.0

                    if global_step % args.checkpointing_steps == 0:
                        if self.accelerator.is_main_process:
                            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                            self.accelerator.save_state(save_path)
                            logger.info(f"Saved state to {save_path}")

                logs = {"step_loss": loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)

                if global_step >= args.max_train_steps:
                    break

            if self.accelerator.is_main_process:
                if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
                    if args.use_ema:
                        # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                        self.ema_unet.store(self.unet.parameters())
                        self.ema_unet.copy_to(self.unet.parameters())
                    if args.use_ema:
                        # Switch back to the original UNet parameters.
                        self.ema_unet.restore(self.unet.parameters())

        # Create the pipeline using the trained modules and save it.
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            self.unet = self.accelerator.unwrap_model(self.unet)
            if args.use_ema:
                self.ema_unet.copy_to(self.unet.parameters())

            pipeline = StableDiffusionPipeline.from_pretrained(
                self.pretrained_model_name_or_path,
                text_encoder=self.text_encoder,
                vae=self.vae,
                unet=self.unet,
                revision=args.revision,
            )
            pipeline.save_pretrained(args.output_dir)

            if args.push_to_hub:
                upload_folder(
                    repo_id=self.repo_id,
                    folder_path=args.output_dir,
                    commit_message="End of training",
                    ignore_patterns=["step_*", "epoch_*"],
                )

        self.accelerator.end_training()
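
A minimal sketch of how the `Trainer` above might be driven; this entry point is not part of the commit, and the model id and training JSON path are placeholders. `parse_args()` supplies the flags defined at the top of the file.

# Hypothetical ReFL entry point (placeholders throughout).
if __name__ == "__main__":
    args = parse_args()
    trainer = Trainer(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # placeholder model id
        train_data_dir="data/refl_prompts.json",  # placeholder JSON file with a "text" column
        args=args,
    )
    trainer.train(args)
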
ImageReward/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .utils import *
from .models import *
from .ReFL import *
ImageReward/models/AestheticScore.py
ADDED
@@ -0,0 +1,95 @@
'''
@File       :   AestheticScore.py
@Time       :   2023/02/12 14:54:00
@Auther     :   Jiazheng Xu
@Contact    :   [email protected]
@Description:   AestheticScore.
* Based on improved-aesthetic-predictor code base
* https://github.com/christophschuhmann/improved-aesthetic-predictor
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import clip


# if you changed the MLP architecture during training, change it also here:
class MLP(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(64, 16),
            # nn.ReLU(),

            nn.Linear(16, 1)
        )

    def forward(self, x):
        return self.layers(x)


class AestheticScore(nn.Module):
    def __init__(self, download_root, device='cpu'):
        super().__init__()
        self.device = device
        self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
                                                     download_root=download_root)
        self.mlp = MLP(768)

        if device == "cpu":
            self.clip_model.float()
        else:
            clip.model.convert_weights(
                self.clip_model)  # Actually this line is unnecessary since clip by default already on float16

        # have clip.logit_scale require no grad.
        self.clip_model.logit_scale.requires_grad_(False)

    def score(self, prompt, image_path):

        if (type(image_path).__name__ == 'list'):
            _, rewards = self.inference_rank(prompt, image_path)
            return rewards

        # image encode
        pil_image = Image.open(image_path)
        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
        image_features = F.normalize(self.clip_model.encode_image(image)).float()

        # score
        rewards = self.mlp(image_features)

        return rewards.detach().cpu().numpy().item()

    def inference_rank(self, prompt, generations_list):

        img_set = []
        for generations in generations_list:
            # image encode
            img_path = generations
            pil_image = Image.open(img_path)
            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
            image_features = F.normalize(self.clip_model.encode_image(image))
            img_set.append(image_features)

        img_features = torch.cat(img_set, 0).float()  # [image_num, feature_dim]
        rewards = self.mlp(img_features)
        rewards = torch.squeeze(rewards)
        _, rank = torch.sort(rewards, dim=0, descending=True)
        _, indices = torch.sort(rank, dim=0)
        indices = indices + 1

        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
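
A usage sketch for `AestheticScore`, assuming direct instantiation; note that the MLP head above is randomly initialized, so in practice its weights are loaded from the aesthetic-predictor checkpoint handled by `ImageReward/utils.py` (not expanded in this view). All paths are placeholders.

from ImageReward.models.AestheticScore import AestheticScore

# download_root caches the CLIP ViT-L/14 weights (placeholder path).
scorer = AestheticScore(download_root="checkpoint/clip", device="cuda")

# Rank several candidate renders; the prompt argument is unused by this scorer.
ranks, rewards = scorer.inference_rank(
    prompt=None,
    generations_list=["render_0.png", "render_1.png"],
)
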
ImageReward/models/BLIP/__init__.py
ADDED
@@ -0,0 +1 @@
from .blip_pretrain import *
ImageReward/models/BLIP/blip.py
ADDED
@@ -0,0 +1,70 @@
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import warnings
warnings.filterwarnings("ignore")

import torch
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from transformers import BertTokenizer
from .vit import VisionTransformer, interpolate_pos_embed


def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer


def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):

    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate
                                           )
    elif vit == 'large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0.1 or drop_path_rate
                                           )
    return visual_encoder, vision_width


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
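A hedged usage sketch for the helpers above: build a ViT backbone with create_vit and tokenize with init_tokenizer. The checkpoint note at the end refers to load_checkpoint; the wrapper model and path are placeholders, not part of this diff.

import torch
from ImageReward.models.BLIP.blip import create_vit, init_tokenizer

# 'base' -> width 768, depth 12; 'large' -> width 1024, depth 24 (both patch size 16)
visual_encoder, vision_width = create_vit('base', image_size=224)
tokenizer = init_tokenizer()   # BERT tokenizer extended with [DEC] and [ENC] special tokens

image = torch.zeros(1, 3, 224, 224)
tokens = visual_encoder(image)          # (1, 197, 768): [CLS] token + 14x14 patch tokens
print(tokens.shape, vision_width)

# load_checkpoint(model, path_or_url) expects a wrapper module with a .visual_encoder
# attribute; it re-interpolates 'visual_encoder.pos_embed' if the resolution changed.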
ImageReward/models/BLIP/blip_pretrain.py
ADDED
@@ -0,0 +1,43 @@
'''
 * Adapted from BLIP (https://github.com/salesforce/BLIP)
'''

import transformers
transformers.logging.set_verbosity_error()

from torch import nn
import os
from .med import BertConfig, BertModel
from .blip import create_vit, init_tokenizer

class BLIP_Pretrain(nn.Module):
    def __init__(self,
                 med_config = "med_config.json",
                 image_size = 224,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256,
                 queue_size = 57600,
                 momentum = 0.995,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

        self.tokenizer = init_tokenizer()
        encoder_config = BertConfig.from_json_file(med_config)
        encoder_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)
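A hedged sketch of how this trimmed BLIP_Pretrain is typically driven: encode the image with the ViT, then run the text through the BERT encoder with cross-attention over the image tokens. The med_config path, prompt, and max_length are placeholders; the config JSON is assumed to enable cross-attention.

import torch
from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain

# med_config must point at a BERT config JSON with add_cross_attention enabled
blip = BLIP_Pretrain(med_config='med_config.json', image_size=224, vit='large')

image = torch.zeros(1, 3, 224, 224)
image_embeds = blip.visual_encoder(image)                         # (1, 197, 1024) for ViT-L/16 at 224px
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)

text = blip.tokenizer(['a vector icon of the Sydney Opera House'], padding='max_length',
                      truncation=True, max_length=35, return_tensors='pt')
output = blip.text_encoder(text.input_ids,
                           attention_mask=text.attention_mask,
                           encoder_hidden_states=image_embeds,
                           encoder_attention_mask=image_atts,
                           return_dict=True)
fused = output.last_hidden_state[:, 0, :]   # [CLS] embedding of the image-conditioned text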
ImageReward/models/BLIP/med.py
ADDED
@@ -0,0 +1,947 @@
1 |
+
'''
|
2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
3 |
+
* Based on huggingface code base
|
4 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
5 |
+
'''
|
6 |
+
|
7 |
+
import math
|
8 |
+
from typing import Tuple
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import Tensor, device, nn
|
12 |
+
import torch.utils.checkpoint
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import CrossEntropyLoss
|
15 |
+
|
16 |
+
from transformers.activations import ACT2FN
|
17 |
+
from transformers.file_utils import (
|
18 |
+
ModelOutput,
|
19 |
+
)
|
20 |
+
from transformers.modeling_outputs import (
|
21 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
22 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
23 |
+
CausalLMOutputWithCrossAttentions,
|
24 |
+
MaskedLMOutput,
|
25 |
+
MultipleChoiceModelOutput,
|
26 |
+
NextSentencePredictorOutput,
|
27 |
+
QuestionAnsweringModelOutput,
|
28 |
+
SequenceClassifierOutput,
|
29 |
+
TokenClassifierOutput,
|
30 |
+
)
|
31 |
+
from transformers.modeling_utils import (
|
32 |
+
PreTrainedModel,
|
33 |
+
apply_chunking_to_forward,
|
34 |
+
find_pruneable_heads_and_indices,
|
35 |
+
prune_linear_layer,
|
36 |
+
)
|
37 |
+
from transformers.utils import logging
|
38 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
39 |
+
|
40 |
+
|
41 |
+
logger = logging.get_logger(__name__)
|
42 |
+
|
43 |
+
|
44 |
+
class BertEmbeddings(nn.Module):
|
45 |
+
"""Construct the embeddings from word and position embeddings."""
|
46 |
+
|
47 |
+
def __init__(self, config):
|
48 |
+
super().__init__()
|
49 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
50 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
51 |
+
|
52 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
53 |
+
# any TensorFlow checkpoint file
|
54 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
55 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
56 |
+
|
57 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
58 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
59 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
60 |
+
|
61 |
+
self.config = config
|
62 |
+
|
63 |
+
def forward(
|
64 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
65 |
+
):
|
66 |
+
if input_ids is not None:
|
67 |
+
input_shape = input_ids.size()
|
68 |
+
else:
|
69 |
+
input_shape = inputs_embeds.size()[:-1]
|
70 |
+
|
71 |
+
seq_length = input_shape[1]
|
72 |
+
|
73 |
+
if position_ids is None:
|
74 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
75 |
+
|
76 |
+
if inputs_embeds is None:
|
77 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
78 |
+
|
79 |
+
embeddings = inputs_embeds
|
80 |
+
|
81 |
+
if self.position_embedding_type == "absolute":
|
82 |
+
position_embeddings = self.position_embeddings(position_ids)
|
83 |
+
embeddings += position_embeddings
|
84 |
+
embeddings = self.LayerNorm(embeddings)
|
85 |
+
embeddings = self.dropout(embeddings)
|
86 |
+
return embeddings
|
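When generation supplies cached key/values, the embedding layer above only sees the newly decoded token, so position_ids are sliced starting at past_key_values_length. A tiny illustration with made-up lengths:

import torch

position_ids = torch.arange(512).expand((1, -1))   # the registered buffer

past_key_values_length, seq_length = 7, 1          # 7 cached tokens, 1 new token
ids = position_ids[:, past_key_values_length: seq_length + past_key_values_length]
print(ids)   # tensor([[7]]) -> the new token gets absolute position 7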
87 |
+
|
88 |
+
|
89 |
+
class BertSelfAttention(nn.Module):
|
90 |
+
def __init__(self, config, is_cross_attention):
|
91 |
+
super().__init__()
|
92 |
+
self.config = config
|
93 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
94 |
+
raise ValueError(
|
95 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
96 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
97 |
+
)
|
98 |
+
|
99 |
+
self.num_attention_heads = config.num_attention_heads
|
100 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
101 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
102 |
+
|
103 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
104 |
+
if is_cross_attention:
|
105 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
106 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
107 |
+
else:
|
108 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
109 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
110 |
+
|
111 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
112 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
113 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
114 |
+
self.max_position_embeddings = config.max_position_embeddings
|
115 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
116 |
+
self.save_attention = False
|
117 |
+
|
118 |
+
def save_attn_gradients(self, attn_gradients):
|
119 |
+
self.attn_gradients = attn_gradients
|
120 |
+
|
121 |
+
def get_attn_gradients(self):
|
122 |
+
return self.attn_gradients
|
123 |
+
|
124 |
+
def save_attention_map(self, attention_map):
|
125 |
+
self.attention_map = attention_map
|
126 |
+
|
127 |
+
def get_attention_map(self):
|
128 |
+
return self.attention_map
|
129 |
+
|
130 |
+
def transpose_for_scores(self, x):
|
131 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
132 |
+
x = x.view(*new_x_shape)
|
133 |
+
return x.permute(0, 2, 1, 3)
|
134 |
+
|
135 |
+
def forward(
|
136 |
+
self,
|
137 |
+
hidden_states,
|
138 |
+
attention_mask=None,
|
139 |
+
head_mask=None,
|
140 |
+
encoder_hidden_states=None,
|
141 |
+
encoder_attention_mask=None,
|
142 |
+
past_key_value=None,
|
143 |
+
output_attentions=False,
|
144 |
+
):
|
145 |
+
mixed_query_layer = self.query(hidden_states)
|
146 |
+
|
147 |
+
# If this is instantiated as a cross-attention module, the keys
|
148 |
+
# and values come from an encoder; the attention mask needs to be
|
149 |
+
# such that the encoder's padding tokens are not attended to.
|
150 |
+
is_cross_attention = encoder_hidden_states is not None
|
151 |
+
|
152 |
+
if is_cross_attention:
|
153 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
154 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
155 |
+
attention_mask = encoder_attention_mask
|
156 |
+
elif past_key_value is not None:
|
157 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
158 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
159 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
160 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
161 |
+
else:
|
162 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
163 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
164 |
+
|
165 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
166 |
+
|
167 |
+
past_key_value = (key_layer, value_layer)
|
168 |
+
|
169 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
170 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
171 |
+
|
172 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
173 |
+
seq_length = hidden_states.size()[1]
|
174 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
175 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
176 |
+
distance = position_ids_l - position_ids_r
|
177 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
178 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
179 |
+
|
180 |
+
if self.position_embedding_type == "relative_key":
|
181 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
182 |
+
attention_scores = attention_scores + relative_position_scores
|
183 |
+
elif self.position_embedding_type == "relative_key_query":
|
184 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
185 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
186 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
187 |
+
|
188 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
189 |
+
if attention_mask is not None:
|
190 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
191 |
+
attention_scores = attention_scores + attention_mask
|
192 |
+
|
193 |
+
# Normalize the attention scores to probabilities.
|
194 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
195 |
+
|
196 |
+
if is_cross_attention and self.save_attention:
|
197 |
+
self.save_attention_map(attention_probs)
|
198 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
199 |
+
|
200 |
+
# This is actually dropping out entire tokens to attend to, which might
|
201 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
202 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
203 |
+
|
204 |
+
# Mask heads if we want to
|
205 |
+
if head_mask is not None:
|
206 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
207 |
+
|
208 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
209 |
+
|
210 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
211 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
212 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
213 |
+
|
214 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
215 |
+
|
216 |
+
outputs = outputs + (past_key_value,)
|
217 |
+
return outputs
|
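The forward above always returns (key_layer, value_layer) as past_key_value so the decoder can reuse them; on the next step the fresh key/value of the single new token is concatenated onto the cache along the sequence axis. A shape-only sketch with hypothetical sizes:

import torch

batch, heads, head_dim = 1, 12, 64
cached_k = torch.randn(batch, heads, 7, head_dim)    # keys for the 7 tokens decoded so far
cached_v = torch.randn(batch, heads, 7, head_dim)
past_key_value = (cached_k, cached_v)

new_k = torch.randn(batch, heads, 1, head_dim)       # projections of the newest token only
new_v = torch.randn(batch, heads, 1, head_dim)

key_layer = torch.cat([past_key_value[0], new_k], dim=2)     # (1, 12, 8, 64)
value_layer = torch.cat([past_key_value[1], new_v], dim=2)   # (1, 12, 8, 64)
print(key_layer.shape, value_layer.shape)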
218 |
+
|
219 |
+
|
220 |
+
class BertSelfOutput(nn.Module):
|
221 |
+
def __init__(self, config):
|
222 |
+
super().__init__()
|
223 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
224 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
225 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
226 |
+
|
227 |
+
def forward(self, hidden_states, input_tensor):
|
228 |
+
hidden_states = self.dense(hidden_states)
|
229 |
+
hidden_states = self.dropout(hidden_states)
|
230 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
231 |
+
return hidden_states
|
232 |
+
|
233 |
+
|
234 |
+
class BertAttention(nn.Module):
|
235 |
+
def __init__(self, config, is_cross_attention=False):
|
236 |
+
super().__init__()
|
237 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
238 |
+
self.output = BertSelfOutput(config)
|
239 |
+
self.pruned_heads = set()
|
240 |
+
|
241 |
+
def prune_heads(self, heads):
|
242 |
+
if len(heads) == 0:
|
243 |
+
return
|
244 |
+
heads, index = find_pruneable_heads_and_indices(
|
245 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
246 |
+
)
|
247 |
+
|
248 |
+
# Prune linear layers
|
249 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
250 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
251 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
252 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
253 |
+
|
254 |
+
# Update hyper params and store pruned heads
|
255 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
256 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
257 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
258 |
+
|
259 |
+
def forward(
|
260 |
+
self,
|
261 |
+
hidden_states,
|
262 |
+
attention_mask=None,
|
263 |
+
head_mask=None,
|
264 |
+
encoder_hidden_states=None,
|
265 |
+
encoder_attention_mask=None,
|
266 |
+
past_key_value=None,
|
267 |
+
output_attentions=False,
|
268 |
+
):
|
269 |
+
self_outputs = self.self(
|
270 |
+
hidden_states,
|
271 |
+
attention_mask,
|
272 |
+
head_mask,
|
273 |
+
encoder_hidden_states,
|
274 |
+
encoder_attention_mask,
|
275 |
+
past_key_value,
|
276 |
+
output_attentions,
|
277 |
+
)
|
278 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
279 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
280 |
+
return outputs
|
281 |
+
|
282 |
+
|
283 |
+
class BertIntermediate(nn.Module):
|
284 |
+
def __init__(self, config):
|
285 |
+
super().__init__()
|
286 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
287 |
+
if isinstance(config.hidden_act, str):
|
288 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
289 |
+
else:
|
290 |
+
self.intermediate_act_fn = config.hidden_act
|
291 |
+
|
292 |
+
def forward(self, hidden_states):
|
293 |
+
hidden_states = self.dense(hidden_states)
|
294 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
295 |
+
return hidden_states
|
296 |
+
|
297 |
+
|
298 |
+
class BertOutput(nn.Module):
|
299 |
+
def __init__(self, config):
|
300 |
+
super().__init__()
|
301 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
302 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
303 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
304 |
+
|
305 |
+
def forward(self, hidden_states, input_tensor):
|
306 |
+
hidden_states = self.dense(hidden_states)
|
307 |
+
hidden_states = self.dropout(hidden_states)
|
308 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
309 |
+
return hidden_states
|
310 |
+
|
311 |
+
|
312 |
+
class BertLayer(nn.Module):
|
313 |
+
def __init__(self, config, layer_num):
|
314 |
+
super().__init__()
|
315 |
+
self.config = config
|
316 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
317 |
+
self.seq_len_dim = 1
|
318 |
+
self.attention = BertAttention(config)
|
319 |
+
self.layer_num = layer_num
|
320 |
+
if self.config.add_cross_attention:
|
321 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
322 |
+
self.intermediate = BertIntermediate(config)
|
323 |
+
self.output = BertOutput(config)
|
324 |
+
|
325 |
+
def forward(
|
326 |
+
self,
|
327 |
+
hidden_states,
|
328 |
+
attention_mask=None,
|
329 |
+
head_mask=None,
|
330 |
+
encoder_hidden_states=None,
|
331 |
+
encoder_attention_mask=None,
|
332 |
+
past_key_value=None,
|
333 |
+
output_attentions=False,
|
334 |
+
mode=None,
|
335 |
+
):
|
336 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
337 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
338 |
+
self_attention_outputs = self.attention(
|
339 |
+
hidden_states,
|
340 |
+
attention_mask,
|
341 |
+
head_mask,
|
342 |
+
output_attentions=output_attentions,
|
343 |
+
past_key_value=self_attn_past_key_value,
|
344 |
+
)
|
345 |
+
attention_output = self_attention_outputs[0]
|
346 |
+
|
347 |
+
outputs = self_attention_outputs[1:-1]
|
348 |
+
present_key_value = self_attention_outputs[-1]
|
349 |
+
|
350 |
+
if mode=='multimodal':
|
351 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
352 |
+
|
353 |
+
cross_attention_outputs = self.crossattention(
|
354 |
+
attention_output,
|
355 |
+
attention_mask,
|
356 |
+
head_mask,
|
357 |
+
encoder_hidden_states,
|
358 |
+
encoder_attention_mask,
|
359 |
+
output_attentions=output_attentions,
|
360 |
+
)
|
361 |
+
attention_output = cross_attention_outputs[0]
|
362 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
363 |
+
layer_output = apply_chunking_to_forward(
|
364 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
365 |
+
)
|
366 |
+
outputs = (layer_output,) + outputs
|
367 |
+
|
368 |
+
outputs = outputs + (present_key_value,)
|
369 |
+
|
370 |
+
return outputs
|
371 |
+
|
372 |
+
def feed_forward_chunk(self, attention_output):
|
373 |
+
intermediate_output = self.intermediate(attention_output)
|
374 |
+
layer_output = self.output(intermediate_output, attention_output)
|
375 |
+
return layer_output
|
376 |
+
|
377 |
+
|
378 |
+
class BertEncoder(nn.Module):
|
379 |
+
def __init__(self, config):
|
380 |
+
super().__init__()
|
381 |
+
self.config = config
|
382 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
383 |
+
self.gradient_checkpointing = False
|
384 |
+
|
385 |
+
def forward(
|
386 |
+
self,
|
387 |
+
hidden_states,
|
388 |
+
attention_mask=None,
|
389 |
+
head_mask=None,
|
390 |
+
encoder_hidden_states=None,
|
391 |
+
encoder_attention_mask=None,
|
392 |
+
past_key_values=None,
|
393 |
+
use_cache=None,
|
394 |
+
output_attentions=False,
|
395 |
+
output_hidden_states=False,
|
396 |
+
return_dict=True,
|
397 |
+
mode='multimodal',
|
398 |
+
):
|
399 |
+
all_hidden_states = () if output_hidden_states else None
|
400 |
+
all_self_attentions = () if output_attentions else None
|
401 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
402 |
+
|
403 |
+
next_decoder_cache = () if use_cache else None
|
404 |
+
|
405 |
+
for i in range(self.config.num_hidden_layers):
|
406 |
+
layer_module = self.layer[i]
|
407 |
+
if output_hidden_states:
|
408 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
409 |
+
|
410 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
411 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
412 |
+
|
413 |
+
if self.gradient_checkpointing and self.training:
|
414 |
+
|
415 |
+
if use_cache:
|
416 |
+
logger.warn(
|
417 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
418 |
+
)
|
419 |
+
use_cache = False
|
420 |
+
|
421 |
+
def create_custom_forward(module):
|
422 |
+
def custom_forward(*inputs):
|
423 |
+
return module(*inputs, past_key_value, output_attentions)
|
424 |
+
|
425 |
+
return custom_forward
|
426 |
+
|
427 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
428 |
+
create_custom_forward(layer_module),
|
429 |
+
hidden_states,
|
430 |
+
attention_mask,
|
431 |
+
layer_head_mask,
|
432 |
+
encoder_hidden_states,
|
433 |
+
encoder_attention_mask,
|
434 |
+
mode=mode,
|
435 |
+
)
|
436 |
+
else:
|
437 |
+
layer_outputs = layer_module(
|
438 |
+
hidden_states,
|
439 |
+
attention_mask,
|
440 |
+
layer_head_mask,
|
441 |
+
encoder_hidden_states,
|
442 |
+
encoder_attention_mask,
|
443 |
+
past_key_value,
|
444 |
+
output_attentions,
|
445 |
+
mode=mode,
|
446 |
+
)
|
447 |
+
|
448 |
+
hidden_states = layer_outputs[0]
|
449 |
+
if use_cache:
|
450 |
+
next_decoder_cache += (layer_outputs[-1],)
|
451 |
+
if output_attentions:
|
452 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
453 |
+
|
454 |
+
if output_hidden_states:
|
455 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
456 |
+
|
457 |
+
if not return_dict:
|
458 |
+
return tuple(
|
459 |
+
v
|
460 |
+
for v in [
|
461 |
+
hidden_states,
|
462 |
+
next_decoder_cache,
|
463 |
+
all_hidden_states,
|
464 |
+
all_self_attentions,
|
465 |
+
all_cross_attentions,
|
466 |
+
]
|
467 |
+
if v is not None
|
468 |
+
)
|
469 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
470 |
+
last_hidden_state=hidden_states,
|
471 |
+
past_key_values=next_decoder_cache,
|
472 |
+
hidden_states=all_hidden_states,
|
473 |
+
attentions=all_self_attentions,
|
474 |
+
cross_attentions=all_cross_attentions,
|
475 |
+
)
|
476 |
+
|
477 |
+
|
478 |
+
class BertPooler(nn.Module):
|
479 |
+
def __init__(self, config):
|
480 |
+
super().__init__()
|
481 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
482 |
+
self.activation = nn.Tanh()
|
483 |
+
|
484 |
+
def forward(self, hidden_states):
|
485 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
486 |
+
# to the first token.
|
487 |
+
first_token_tensor = hidden_states[:, 0]
|
488 |
+
pooled_output = self.dense(first_token_tensor)
|
489 |
+
pooled_output = self.activation(pooled_output)
|
490 |
+
return pooled_output
|
491 |
+
|
492 |
+
|
493 |
+
class BertPredictionHeadTransform(nn.Module):
|
494 |
+
def __init__(self, config):
|
495 |
+
super().__init__()
|
496 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
497 |
+
if isinstance(config.hidden_act, str):
|
498 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
499 |
+
else:
|
500 |
+
self.transform_act_fn = config.hidden_act
|
501 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
502 |
+
|
503 |
+
def forward(self, hidden_states):
|
504 |
+
hidden_states = self.dense(hidden_states)
|
505 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
506 |
+
hidden_states = self.LayerNorm(hidden_states)
|
507 |
+
return hidden_states
|
508 |
+
|
509 |
+
|
510 |
+
class BertLMPredictionHead(nn.Module):
|
511 |
+
def __init__(self, config):
|
512 |
+
super().__init__()
|
513 |
+
self.transform = BertPredictionHeadTransform(config)
|
514 |
+
|
515 |
+
# The output weights are the same as the input embeddings, but there is
|
516 |
+
# an output-only bias for each token.
|
517 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
518 |
+
|
519 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
520 |
+
|
521 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
522 |
+
self.decoder.bias = self.bias
|
523 |
+
|
524 |
+
def forward(self, hidden_states):
|
525 |
+
hidden_states = self.transform(hidden_states)
|
526 |
+
hidden_states = self.decoder(hidden_states)
|
527 |
+
return hidden_states
|
528 |
+
|
529 |
+
|
530 |
+
class BertOnlyMLMHead(nn.Module):
|
531 |
+
def __init__(self, config):
|
532 |
+
super().__init__()
|
533 |
+
self.predictions = BertLMPredictionHead(config)
|
534 |
+
|
535 |
+
def forward(self, sequence_output):
|
536 |
+
prediction_scores = self.predictions(sequence_output)
|
537 |
+
return prediction_scores
|
538 |
+
|
539 |
+
|
540 |
+
class BertPreTrainedModel(PreTrainedModel):
|
541 |
+
"""
|
542 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
543 |
+
models.
|
544 |
+
"""
|
545 |
+
|
546 |
+
config_class = BertConfig
|
547 |
+
base_model_prefix = "bert"
|
548 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
549 |
+
|
550 |
+
def _init_weights(self, module):
|
551 |
+
""" Initialize the weights """
|
552 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
553 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
554 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
555 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
556 |
+
elif isinstance(module, nn.LayerNorm):
|
557 |
+
module.bias.data.zero_()
|
558 |
+
module.weight.data.fill_(1.0)
|
559 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
560 |
+
module.bias.data.zero_()
|
561 |
+
|
562 |
+
|
563 |
+
class BertModel(BertPreTrainedModel):
|
564 |
+
"""
|
565 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
566 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
567 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
568 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
569 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
570 |
+
input to the forward pass.
|
571 |
+
"""
|
572 |
+
|
573 |
+
def __init__(self, config, add_pooling_layer=True):
|
574 |
+
super().__init__(config)
|
575 |
+
self.config = config
|
576 |
+
|
577 |
+
self.embeddings = BertEmbeddings(config)
|
578 |
+
|
579 |
+
self.encoder = BertEncoder(config)
|
580 |
+
|
581 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
582 |
+
|
583 |
+
self.init_weights()
|
584 |
+
|
585 |
+
|
586 |
+
def get_input_embeddings(self):
|
587 |
+
return self.embeddings.word_embeddings
|
588 |
+
|
589 |
+
def set_input_embeddings(self, value):
|
590 |
+
self.embeddings.word_embeddings = value
|
591 |
+
|
592 |
+
def _prune_heads(self, heads_to_prune):
|
593 |
+
"""
|
594 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
595 |
+
class PreTrainedModel
|
596 |
+
"""
|
597 |
+
for layer, heads in heads_to_prune.items():
|
598 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
599 |
+
|
600 |
+
|
601 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
602 |
+
"""
|
603 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
604 |
+
|
605 |
+
Arguments:
|
606 |
+
attention_mask (:obj:`torch.Tensor`):
|
607 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
608 |
+
input_shape (:obj:`Tuple[int]`):
|
609 |
+
The shape of the input to the model.
|
610 |
+
device: (:obj:`torch.device`):
|
611 |
+
The device of the input to the model.
|
612 |
+
|
613 |
+
Returns:
|
614 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
615 |
+
"""
|
616 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
617 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
618 |
+
if attention_mask.dim() == 3:
|
619 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
620 |
+
elif attention_mask.dim() == 2:
|
621 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
622 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
623 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
624 |
+
if is_decoder:
|
625 |
+
batch_size, seq_length = input_shape
|
626 |
+
|
627 |
+
seq_ids = torch.arange(seq_length, device=device)
|
628 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
629 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
630 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
631 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
632 |
+
|
633 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
634 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
635 |
+
causal_mask = torch.cat(
|
636 |
+
[
|
637 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
638 |
+
causal_mask,
|
639 |
+
],
|
640 |
+
axis=-1,
|
641 |
+
)
|
642 |
+
|
643 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
644 |
+
else:
|
645 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
646 |
+
else:
|
647 |
+
raise ValueError(
|
648 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
649 |
+
input_shape, attention_mask.shape
|
650 |
+
)
|
651 |
+
)
|
652 |
+
|
653 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
654 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
655 |
+
# positions we want to attend and -10000.0 for masked positions.
|
656 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
657 |
+
# effectively the same as removing these entirely.
|
658 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
659 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
660 |
+
return extended_attention_mask
|
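The extended mask built above is additive: allowed positions become 0 and masked positions become -10000, which drives their softmax weight to roughly zero. A small numeric check:

import torch

attention_mask = torch.tensor([[1., 1., 1., 0., 0.]])      # attend to 3 of 5 tokens

extended = attention_mask[:, None, None, :]                 # broadcast to (batch, heads, query, key)
extended = (1.0 - extended) * -10000.0

scores = torch.zeros(1, 1, 5, 5)                            # stand-in for raw attention scores
probs = (scores + extended).softmax(dim=-1)
print(probs[0, 0, 0])   # ~[0.333, 0.333, 0.333, 0.000, 0.000]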
661 |
+
|
662 |
+
def forward(
|
663 |
+
self,
|
664 |
+
input_ids=None,
|
665 |
+
attention_mask=None,
|
666 |
+
position_ids=None,
|
667 |
+
head_mask=None,
|
668 |
+
inputs_embeds=None,
|
669 |
+
encoder_embeds=None,
|
670 |
+
encoder_hidden_states=None,
|
671 |
+
encoder_attention_mask=None,
|
672 |
+
past_key_values=None,
|
673 |
+
use_cache=None,
|
674 |
+
output_attentions=None,
|
675 |
+
output_hidden_states=None,
|
676 |
+
return_dict=None,
|
677 |
+
is_decoder=False,
|
678 |
+
mode='multimodal',
|
679 |
+
):
|
680 |
+
r"""
|
681 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
682 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
683 |
+
the model is configured as a decoder.
|
684 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
685 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
686 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
687 |
+
- 1 for tokens that are **not masked**,
|
688 |
+
- 0 for tokens that are **masked**.
|
689 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
690 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
691 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
692 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
693 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
694 |
+
use_cache (:obj:`bool`, `optional`):
|
695 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
696 |
+
decoding (see :obj:`past_key_values`).
|
697 |
+
"""
|
698 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
699 |
+
output_hidden_states = (
|
700 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
701 |
+
)
|
702 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
703 |
+
|
704 |
+
if is_decoder:
|
705 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
706 |
+
else:
|
707 |
+
use_cache = False
|
708 |
+
|
709 |
+
if input_ids is not None and inputs_embeds is not None:
|
710 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
711 |
+
elif input_ids is not None:
|
712 |
+
input_shape = input_ids.size()
|
713 |
+
batch_size, seq_length = input_shape
|
714 |
+
device = input_ids.device
|
715 |
+
elif inputs_embeds is not None:
|
716 |
+
input_shape = inputs_embeds.size()[:-1]
|
717 |
+
batch_size, seq_length = input_shape
|
718 |
+
device = inputs_embeds.device
|
719 |
+
elif encoder_embeds is not None:
|
720 |
+
input_shape = encoder_embeds.size()[:-1]
|
721 |
+
batch_size, seq_length = input_shape
|
722 |
+
device = encoder_embeds.device
|
723 |
+
else:
|
724 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
725 |
+
|
726 |
+
# past_key_values_length
|
727 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
728 |
+
|
729 |
+
if attention_mask is None:
|
730 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
731 |
+
|
732 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
733 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
734 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
735 |
+
device, is_decoder)
|
736 |
+
|
737 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
738 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
739 |
+
if encoder_hidden_states is not None:
|
740 |
+
if type(encoder_hidden_states) == list:
|
741 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
742 |
+
else:
|
743 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
744 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
745 |
+
|
746 |
+
if type(encoder_attention_mask) == list:
|
747 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
748 |
+
elif encoder_attention_mask is None:
|
749 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
750 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
751 |
+
else:
|
752 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
753 |
+
else:
|
754 |
+
encoder_extended_attention_mask = None
|
755 |
+
|
756 |
+
# Prepare head mask if needed
|
757 |
+
# 1.0 in head_mask indicate we keep the head
|
758 |
+
# attention_probs has shape bsz x n_heads x N x N
|
759 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
760 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
761 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
762 |
+
|
763 |
+
if encoder_embeds is None:
|
764 |
+
embedding_output = self.embeddings(
|
765 |
+
input_ids=input_ids,
|
766 |
+
position_ids=position_ids,
|
767 |
+
inputs_embeds=inputs_embeds,
|
768 |
+
past_key_values_length=past_key_values_length,
|
769 |
+
)
|
770 |
+
else:
|
771 |
+
embedding_output = encoder_embeds
|
772 |
+
|
773 |
+
encoder_outputs = self.encoder(
|
774 |
+
embedding_output,
|
775 |
+
attention_mask=extended_attention_mask,
|
776 |
+
head_mask=head_mask,
|
777 |
+
encoder_hidden_states=encoder_hidden_states,
|
778 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
779 |
+
past_key_values=past_key_values,
|
780 |
+
use_cache=use_cache,
|
781 |
+
output_attentions=output_attentions,
|
782 |
+
output_hidden_states=output_hidden_states,
|
783 |
+
return_dict=return_dict,
|
784 |
+
mode=mode,
|
785 |
+
)
|
786 |
+
sequence_output = encoder_outputs[0]
|
787 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
788 |
+
|
789 |
+
if not return_dict:
|
790 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
791 |
+
|
792 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
793 |
+
last_hidden_state=sequence_output,
|
794 |
+
pooler_output=pooled_output,
|
795 |
+
past_key_values=encoder_outputs.past_key_values,
|
796 |
+
hidden_states=encoder_outputs.hidden_states,
|
797 |
+
attentions=encoder_outputs.attentions,
|
798 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
799 |
+
)
|
800 |
+
|
801 |
+
|
802 |
+
|
803 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
804 |
+
|
805 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
806 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
807 |
+
|
808 |
+
def __init__(self, config):
|
809 |
+
super().__init__(config)
|
810 |
+
|
811 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
812 |
+
self.cls = BertOnlyMLMHead(config)
|
813 |
+
|
814 |
+
self.init_weights()
|
815 |
+
|
816 |
+
def get_output_embeddings(self):
|
817 |
+
return self.cls.predictions.decoder
|
818 |
+
|
819 |
+
def set_output_embeddings(self, new_embeddings):
|
820 |
+
self.cls.predictions.decoder = new_embeddings
|
821 |
+
|
822 |
+
def forward(
|
823 |
+
self,
|
824 |
+
input_ids=None,
|
825 |
+
attention_mask=None,
|
826 |
+
position_ids=None,
|
827 |
+
head_mask=None,
|
828 |
+
inputs_embeds=None,
|
829 |
+
encoder_hidden_states=None,
|
830 |
+
encoder_attention_mask=None,
|
831 |
+
labels=None,
|
832 |
+
past_key_values=None,
|
833 |
+
use_cache=None,
|
834 |
+
output_attentions=None,
|
835 |
+
output_hidden_states=None,
|
836 |
+
return_dict=None,
|
837 |
+
return_logits=False,
|
838 |
+
is_decoder=True,
|
839 |
+
reduction='mean',
|
840 |
+
mode='multimodal',
|
841 |
+
):
|
842 |
+
r"""
|
843 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
844 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
845 |
+
the model is configured as a decoder.
|
846 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
847 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
848 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
849 |
+
- 1 for tokens that are **not masked**,
|
850 |
+
- 0 for tokens that are **masked**.
|
851 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
852 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
853 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
854 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
855 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
856 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
857 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
858 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
859 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
860 |
+
use_cache (:obj:`bool`, `optional`):
|
861 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
862 |
+
decoding (see :obj:`past_key_values`).
|
863 |
+
Returns:
|
864 |
+
Example::
|
865 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
866 |
+
>>> import torch
|
867 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
868 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
869 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
870 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
871 |
+
>>> outputs = model(**inputs)
|
872 |
+
>>> prediction_logits = outputs.logits
|
873 |
+
"""
|
874 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
875 |
+
if labels is not None:
|
876 |
+
use_cache = False
|
877 |
+
|
878 |
+
outputs = self.bert(
|
879 |
+
input_ids,
|
880 |
+
attention_mask=attention_mask,
|
881 |
+
position_ids=position_ids,
|
882 |
+
head_mask=head_mask,
|
883 |
+
inputs_embeds=inputs_embeds,
|
884 |
+
encoder_hidden_states=encoder_hidden_states,
|
885 |
+
encoder_attention_mask=encoder_attention_mask,
|
886 |
+
past_key_values=past_key_values,
|
887 |
+
use_cache=use_cache,
|
888 |
+
output_attentions=output_attentions,
|
889 |
+
output_hidden_states=output_hidden_states,
|
890 |
+
return_dict=return_dict,
|
891 |
+
is_decoder=is_decoder,
|
892 |
+
mode=mode,
|
893 |
+
)
|
894 |
+
|
895 |
+
sequence_output = outputs[0]
|
896 |
+
prediction_scores = self.cls(sequence_output)
|
897 |
+
|
898 |
+
if return_logits:
|
899 |
+
return prediction_scores[:, :-1, :].contiguous()
|
900 |
+
|
901 |
+
lm_loss = None
|
902 |
+
if labels is not None:
|
903 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
904 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
905 |
+
labels = labels[:, 1:].contiguous()
|
906 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
907 |
+
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
908 |
+
if reduction=='none':
|
909 |
+
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
910 |
+
|
911 |
+
if not return_dict:
|
912 |
+
output = (prediction_scores,) + outputs[2:]
|
913 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
914 |
+
|
915 |
+
return CausalLMOutputWithCrossAttentions(
|
916 |
+
loss=lm_loss,
|
917 |
+
logits=prediction_scores,
|
918 |
+
past_key_values=outputs.past_key_values,
|
919 |
+
hidden_states=outputs.hidden_states,
|
920 |
+
attentions=outputs.attentions,
|
921 |
+
cross_attentions=outputs.cross_attentions,
|
922 |
+
)
|
923 |
+
|
924 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
925 |
+
input_shape = input_ids.shape
|
926 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
927 |
+
if attention_mask is None:
|
928 |
+
attention_mask = input_ids.new_ones(input_shape)
|
929 |
+
|
930 |
+
# cut decoder_input_ids if past is used
|
931 |
+
if past is not None:
|
932 |
+
input_ids = input_ids[:, -1:]
|
933 |
+
|
934 |
+
return {
|
935 |
+
"input_ids": input_ids,
|
936 |
+
"attention_mask": attention_mask,
|
937 |
+
"past_key_values": past,
|
938 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
939 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
940 |
+
"is_decoder": True,
|
941 |
+
}
|
942 |
+
|
943 |
+
def _reorder_cache(self, past, beam_idx):
|
944 |
+
reordered_past = ()
|
945 |
+
for layer_past in past:
|
946 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
947 |
+
return reordered_past
|
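BertLMHeadModel above trains left-to-right: the logits at position t are scored against the token at t+1, so both tensors are shifted by one before the cross-entropy with 0.1 label smoothing. A toy-sized sketch of that alignment:

import torch
from torch.nn import CrossEntropyLoss

vocab_size, batch, seq_len = 11, 1, 4
prediction_scores = torch.randn(batch, seq_len, vocab_size)   # logits for positions 0..3
labels = torch.tensor([[5, 7, 2, 9]])                         # the caption token ids

shifted_scores = prediction_scores[:, :-1, :].contiguous()    # predictions made at positions 0..2
shifted_labels = labels[:, 1:].contiguous()                   # targets are tokens 1..3

loss_fct = CrossEntropyLoss(label_smoothing=0.1)
lm_loss = loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))
print(lm_loss.item())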
ImageReward/models/BLIP/vit.py
ADDED
@@ -0,0 +1,301 @@
1 |
+
'''
|
2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
3 |
+
* Based on timm code base
|
4 |
+
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
5 |
+
'''
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from functools import partial
|
11 |
+
|
12 |
+
from timm.models.vision_transformer import _cfg, PatchEmbed
|
13 |
+
from timm.models.registry import register_model
|
14 |
+
from timm.models.layers import trunc_normal_, DropPath
|
15 |
+
from timm.models.helpers import named_apply, adapt_input_conv
|
16 |
+
|
17 |
+
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
18 |
+
|
19 |
+
class Mlp(nn.Module):
|
20 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
21 |
+
"""
|
22 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
23 |
+
super().__init__()
|
24 |
+
out_features = out_features or in_features
|
25 |
+
hidden_features = hidden_features or in_features
|
26 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
27 |
+
self.act = act_layer()
|
28 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
29 |
+
self.drop = nn.Dropout(drop)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
x = self.fc1(x)
|
33 |
+
x = self.act(x)
|
34 |
+
x = self.drop(x)
|
35 |
+
x = self.fc2(x)
|
36 |
+
x = self.drop(x)
|
37 |
+
return x
|
38 |
+
|
39 |
+
|
40 |
+
class Attention(nn.Module):
|
41 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
42 |
+
super().__init__()
|
43 |
+
self.num_heads = num_heads
|
44 |
+
head_dim = dim // num_heads
|
45 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
46 |
+
self.scale = qk_scale or head_dim ** -0.5
|
47 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
48 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
49 |
+
self.proj = nn.Linear(dim, dim)
|
50 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
51 |
+
self.attn_gradients = None
|
52 |
+
self.attention_map = None
|
53 |
+
|
54 |
+
def save_attn_gradients(self, attn_gradients):
|
55 |
+
self.attn_gradients = attn_gradients
|
56 |
+
|
57 |
+
def get_attn_gradients(self):
|
58 |
+
return self.attn_gradients
|
59 |
+
|
60 |
+
def save_attention_map(self, attention_map):
|
61 |
+
self.attention_map = attention_map
|
62 |
+
|
63 |
+
def get_attention_map(self):
|
64 |
+
return self.attention_map
|
65 |
+
|
66 |
+
def forward(self, x, register_hook=False):
|
67 |
+
B, N, C = x.shape
|
68 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
69 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
70 |
+
|
71 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
72 |
+
attn = attn.softmax(dim=-1)
|
73 |
+
attn = self.attn_drop(attn)
|
74 |
+
|
75 |
+
if register_hook:
|
76 |
+
self.save_attention_map(attn)
|
77 |
+
attn.register_hook(self.save_attn_gradients)
|
78 |
+
|
79 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
80 |
+
x = self.proj(x)
|
81 |
+
x = self.proj_drop(x)
|
82 |
+
return x
|
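The fused qkv projection above produces all three attention inputs with one Linear and recovers them by reshaping. A shape walk-through with hypothetical sizes:

import torch
import torch.nn as nn

B, N, C, num_heads = 2, 197, 768, 12
head_dim = C // num_heads
qkv_proj = nn.Linear(C, C * 3, bias=True)

x = torch.randn(B, N, C)
qkv = qkv_proj(x).reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                      # each (B, num_heads, N, head_dim)

attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5   # scaled dot-product scores
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)     # back to (B, N, C)
print(out.shape)                                       # torch.Size([2, 197, 768])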
83 |
+
|
84 |
+
|
85 |
+
class Block(nn.Module):
|
86 |
+
|
87 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
88 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
89 |
+
super().__init__()
|
90 |
+
self.norm1 = norm_layer(dim)
|
91 |
+
self.attn = Attention(
|
92 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
93 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
94 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
95 |
+
self.norm2 = norm_layer(dim)
|
96 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
97 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
98 |
+
|
99 |
+
if use_grad_checkpointing:
|
100 |
+
self.attn = checkpoint_wrapper(self.attn)
|
101 |
+
self.mlp = checkpoint_wrapper(self.mlp)
|
102 |
+
|
103 |
+
def forward(self, x, register_hook=False):
|
104 |
+
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
105 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
class VisionTransformer(nn.Module):
|
110 |
+
""" Vision Transformer
|
111 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
112 |
+
https://arxiv.org/abs/2010.11929
|
113 |
+
"""
|
114 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
115 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
116 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
117 |
+
use_grad_checkpointing=False, ckpt_layer=0):
|
118 |
+
"""
|
119 |
+
Args:
|
120 |
+
img_size (int, tuple): input image size
|
121 |
+
patch_size (int, tuple): patch size
|
122 |
+
in_chans (int): number of input channels
|
123 |
+
num_classes (int): number of classes for classification head
|
124 |
+
embed_dim (int): embedding dimension
|
125 |
+
depth (int): depth of transformer
|
126 |
+
num_heads (int): number of attention heads
|
127 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
128 |
+
qkv_bias (bool): enable bias for qkv if True
|
129 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
130 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
131 |
+
drop_rate (float): dropout rate
|
132 |
+
attn_drop_rate (float): attention dropout rate
|
133 |
+
drop_path_rate (float): stochastic depth rate
|
134 |
+
norm_layer: (nn.Module): normalization layer
|
135 |
+
"""
|
136 |
+
super().__init__()
|
137 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
138 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
139 |
+
|
140 |
+
self.patch_embed = PatchEmbed(
|
141 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
142 |
+
|
143 |
+
num_patches = self.patch_embed.num_patches
|
144 |
+
|
145 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
146 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
147 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
148 |
+
|
149 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
150 |
+
self.blocks = nn.ModuleList([
|
151 |
+
Block(
|
152 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
153 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
154 |
+
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
155 |
+
)
|
156 |
+
for i in range(depth)])
|
157 |
+
self.norm = norm_layer(embed_dim)
|
158 |
+
|
159 |
+
trunc_normal_(self.pos_embed, std=.02)
|
160 |
+
trunc_normal_(self.cls_token, std=.02)
|
161 |
+
self.apply(self._init_weights)
|
162 |
+
|
163 |
+
def _init_weights(self, m):
|
164 |
+
if isinstance(m, nn.Linear):
|
165 |
+
trunc_normal_(m.weight, std=.02)
|
166 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
167 |
+
nn.init.constant_(m.bias, 0)
|
168 |
+
elif isinstance(m, nn.LayerNorm):
|
169 |
+
nn.init.constant_(m.bias, 0)
|
170 |
+
nn.init.constant_(m.weight, 1.0)
|
171 |
+
|
172 |
+
@torch.jit.ignore
|
173 |
+
def no_weight_decay(self):
|
174 |
+
return {'pos_embed', 'cls_token'}
|
175 |
+
|
176 |
+
def forward(self, x, register_blk=-1):
|
177 |
+
B = x.shape[0]
|
178 |
+
x = self.patch_embed(x)
|
179 |
+
|
180 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
181 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
182 |
+
|
183 |
+
x = x + self.pos_embed[:,:x.size(1),:]
|
184 |
+
x = self.pos_drop(x)
|
185 |
+
|
186 |
+
for i,blk in enumerate(self.blocks):
|
187 |
+
x = blk(x, register_blk==i)
|
188 |
+
x = self.norm(x)
|
189 |
+
|
190 |
+
return x
|
191 |
+
|
192 |
+
@torch.jit.ignore()
|
193 |
+
def load_pretrained(self, checkpoint_path, prefix=''):
|
194 |
+
_load_weights(self, checkpoint_path, prefix)
|
195 |
+
|
196 |
+
|
197 |
+
@torch.no_grad()
|
198 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
199 |
+
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
200 |
+
"""
|
201 |
+
import numpy as np
|
202 |
+
|
203 |
+
def _n2p(w, t=True):
|
204 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
205 |
+
w = w.flatten()
|
206 |
+
if t:
|
207 |
+
if w.ndim == 4:
|
208 |
+
w = w.transpose([3, 2, 0, 1])
|
209 |
+
elif w.ndim == 3:
|
210 |
+
w = w.transpose([2, 0, 1])
|
211 |
+
elif w.ndim == 2:
|
212 |
+
w = w.transpose([1, 0])
|
213 |
+
return torch.from_numpy(w)
|
214 |
+
|
215 |
+
w = np.load(checkpoint_path)
|
216 |
+
if not prefix and 'opt/target/embedding/kernel' in w:
|
217 |
+
prefix = 'opt/target/'
|
218 |
+
|
219 |
+
if hasattr(model.patch_embed, 'backbone'):
|
220 |
+
# hybrid
|
221 |
+
backbone = model.patch_embed.backbone
|
222 |
+
stem_only = not hasattr(backbone, 'stem')
|
223 |
+
stem = backbone if stem_only else backbone.stem
|
224 |
+
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
225 |
+
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
226 |
+
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
227 |
+
if not stem_only:
|
228 |
+
for i, stage in enumerate(backbone.stages):
|
229 |
+
for j, block in enumerate(stage.blocks):
|
230 |
+
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
231 |
+
for r in range(3):
|
232 |
+
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
233 |
+
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
234 |
+
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
235 |
+
if block.downsample is not None:
|
236 |
+
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
237 |
+
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
238 |
+
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
239 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
240 |
+
else:
|
241 |
+
embed_conv_w = adapt_input_conv(
|
242 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
243 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
244 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
245 |
+
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
246 |
+
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
247 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
248 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
249 |
+
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
250 |
+
model.pos_embed.copy_(pos_embed_w)
|
251 |
+
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
252 |
+
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
253 |
+
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
254 |
+
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
255 |
+
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
256 |
+
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
257 |
+
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
258 |
+
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
259 |
+
for i, block in enumerate(model.blocks.children()):
|
260 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
261 |
+
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
262 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
263 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
264 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
265 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
266 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
267 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
268 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
269 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
270 |
+
for r in range(2):
|
271 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
272 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
273 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
274 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
275 |
+
|
276 |
+
|
277 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
278 |
+
# interpolate position embedding
|
279 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
280 |
+
num_patches = visual_encoder.patch_embed.num_patches
|
281 |
+
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
282 |
+
# height (== width) for the checkpoint position embedding
|
283 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
284 |
+
# height (== width) for the new position embedding
|
285 |
+
new_size = int(num_patches ** 0.5)
|
286 |
+
|
287 |
+
if orig_size!=new_size:
|
288 |
+
# class_token and dist_token are kept unchanged
|
289 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
290 |
+
# only the position tokens are interpolated
|
291 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
292 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
293 |
+
pos_tokens = torch.nn.functional.interpolate(
|
294 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
295 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
296 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
297 |
+
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
298 |
+
|
299 |
+
return new_pos_embed
|
300 |
+
else:
|
301 |
+
return pos_embed_checkpoint
|
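For orientation, a minimal sketch (not part of this commit) of how `interpolate_pos_embed` above is typically used: resize a checkpoint's position embedding to match an encoder built at a different image size before calling `load_state_dict`. The checkpoint path, its key layout, and the encoder hyper-parameters are illustrative assumptions.

```python
import torch
from ImageReward.models.BLIP.vit import VisionTransformer, interpolate_pos_embed

# Hypothetical: adapt a ViT checkpoint trained at 224x224 to a 384x384 encoder.
encoder = VisionTransformer(img_size=384, patch_size=16, embed_dim=768,
                            depth=12, num_heads=12)
state_dict = torch.load('vit_base_224.pth', map_location='cpu')  # placeholder path
# reshape the [1, N+1, C] position embedding to the new patch grid
state_dict['pos_embed'] = interpolate_pos_embed(state_dict['pos_embed'], encoder)
msg = encoder.load_state_dict(state_dict, strict=False)
print(msg.missing_keys)
```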
ImageReward/models/BLIPScore.py
ADDED
@@ -0,0 +1,97 @@
1 |
+
'''
|
2 |
+
@File : BLIPScore.py
|
3 |
+
@Time : 2023/02/19 20:48:00
|
4 |
+
@Author : Jiazheng Xu
|
5 |
+
@Contact : [email protected]
|
6 |
+
@Description: BLIPScore.
|
7 |
+
* Based on BLIP code base
|
8 |
+
* https://github.com/salesforce/BLIP
|
9 |
+
'''
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from PIL import Image
|
15 |
+
from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain
|
16 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
17 |
+
|
18 |
+
try:
|
19 |
+
from torchvision.transforms import InterpolationMode
|
20 |
+
BICUBIC = InterpolationMode.BICUBIC
|
21 |
+
except ImportError:
|
22 |
+
BICUBIC = Image.BICUBIC
|
23 |
+
|
24 |
+
|
25 |
+
def _convert_image_to_rgb(image):
|
26 |
+
return image.convert("RGB")
|
27 |
+
|
28 |
+
|
29 |
+
def _transform(n_px):
|
30 |
+
return Compose([
|
31 |
+
Resize(n_px, interpolation=BICUBIC),
|
32 |
+
CenterCrop(n_px),
|
33 |
+
_convert_image_to_rgb,
|
34 |
+
ToTensor(),
|
35 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
36 |
+
])
|
37 |
+
|
38 |
+
|
39 |
+
class BLIPScore(nn.Module):
|
40 |
+
def __init__(self, med_config, device='cpu'):
|
41 |
+
super().__init__()
|
42 |
+
self.device = device
|
43 |
+
|
44 |
+
self.preprocess = _transform(224)
|
45 |
+
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
|
46 |
+
|
47 |
+
|
48 |
+
def score(self, prompt, image_path):
|
49 |
+
|
50 |
+
if (type(image_path).__name__=='list'):
|
51 |
+
_, rewards = self.inference_rank(prompt, image_path)
|
52 |
+
return rewards
|
53 |
+
|
54 |
+
# text encode
|
55 |
+
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
56 |
+
text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
|
57 |
+
txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:]))
|
58 |
+
|
59 |
+
# image encode
|
60 |
+
pil_image = Image.open(image_path)
|
61 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
62 |
+
image_embeds = self.blip.visual_encoder(image)
|
63 |
+
image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1)
|
64 |
+
|
65 |
+
# score
|
66 |
+
rewards = torch.sum(torch.mul(txt_feature, image_features), dim=1, keepdim=True)
|
67 |
+
|
68 |
+
return rewards.detach().cpu().numpy().item()
|
69 |
+
|
70 |
+
|
71 |
+
def inference_rank(self, prompt, generations_list):
|
72 |
+
|
73 |
+
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
74 |
+
text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
|
75 |
+
txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:]))
|
76 |
+
|
77 |
+
txt_set = []
|
78 |
+
img_set = []
|
79 |
+
for generations in generations_list:
|
80 |
+
# image encode
|
81 |
+
img_path = generations
|
82 |
+
pil_image = Image.open(img_path)
|
83 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
84 |
+
image_embeds = self.blip.visual_encoder(image)
|
85 |
+
image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1)
|
86 |
+
img_set.append(image_features)
|
87 |
+
txt_set.append(txt_feature)
|
88 |
+
|
89 |
+
txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
|
90 |
+
img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
|
91 |
+
rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
|
92 |
+
rewards = torch.squeeze(rewards)
|
93 |
+
_, rank = torch.sort(rewards, dim=0, descending=True)
|
94 |
+
_, indices = torch.sort(rank, dim=0)
|
95 |
+
indices = indices + 1
|
96 |
+
|
97 |
+
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
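As a usage sketch (not part of this commit): `score` handles a single image path, while `inference_rank` takes a list and returns 1-based ranks plus the raw text-image cosine rewards. The example goes through `load_score` from `ImageReward/utils.py` further down in this diff, assumes it is re-exported by the package `__init__.py`, and uses placeholder image paths.

```python
import ImageReward as RM  # package added by this commit

blip = RM.load_score("BLIP", device="cuda")  # downloads BLIP weights + med_config on first use
prompt = "an image of Batman. full body action pose"
single_reward = blip.score(prompt, "batman_0.png")
ranks, rewards = blip.inference_rank(prompt, ["batman_0.png", "batman_1.png", "batman_2.png"])
print(single_reward)  # cosine similarity of the text and image embeddings
print(ranks)          # e.g. [2, 1, 3]; rank 1 marks the best-matching image
```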
ImageReward/models/CLIPScore.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
'''
|
2 |
+
@File : CLIPScore.py
|
3 |
+
@Time : 2023/02/12 13:14:00
|
4 |
+
@Author : Jiazheng Xu
|
5 |
+
@Contact : [email protected]
|
6 |
+
@Description: CLIPScore.
|
7 |
+
* Based on CLIP code base
|
8 |
+
* https://github.com/openai/CLIP
|
9 |
+
'''
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from PIL import Image
|
15 |
+
import clip
|
16 |
+
|
17 |
+
class CLIPScore(nn.Module):
|
18 |
+
def __init__(self, download_root, device='cpu'):
|
19 |
+
super().__init__()
|
20 |
+
self.device = device
|
21 |
+
self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
|
22 |
+
download_root=download_root)
|
23 |
+
|
24 |
+
if device == "cpu":
|
25 |
+
self.clip_model.float()
|
26 |
+
else:
|
27 |
+
clip.model.convert_weights(self.clip_model) # actually this line is unnecessary since CLIP is already in float16 by default
|
28 |
+
|
29 |
+
# have clip.logit_scale require no grad.
|
30 |
+
self.clip_model.logit_scale.requires_grad_(False)
|
31 |
+
|
32 |
+
|
33 |
+
def score(self, prompt, image_path):
|
34 |
+
|
35 |
+
if (type(image_path).__name__=='list'):
|
36 |
+
_, rewards = self.inference_rank(prompt, image_path)
|
37 |
+
return rewards
|
38 |
+
|
39 |
+
# text encode
|
40 |
+
text = clip.tokenize(prompt, truncate=True).to(self.device)
|
41 |
+
txt_features = F.normalize(self.clip_model.encode_text(text))
|
42 |
+
|
43 |
+
# image encode
|
44 |
+
pil_image = Image.open(image_path)
|
45 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
46 |
+
image_features = F.normalize(self.clip_model.encode_image(image))
|
47 |
+
|
48 |
+
# score
|
49 |
+
rewards = torch.sum(torch.mul(txt_features, image_features), dim=1, keepdim=True)
|
50 |
+
|
51 |
+
return rewards.detach().cpu().numpy().item()
|
52 |
+
|
53 |
+
|
54 |
+
def inference_rank(self, prompt, generations_list):
|
55 |
+
|
56 |
+
text = clip.tokenize(prompt, truncate=True).to(self.device)
|
57 |
+
txt_feature = F.normalize(self.clip_model.encode_text(text))
|
58 |
+
|
59 |
+
txt_set = []
|
60 |
+
img_set = []
|
61 |
+
for generations in generations_list:
|
62 |
+
# image encode
|
63 |
+
img_path = generations
|
64 |
+
pil_image = Image.open(img_path)
|
65 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
66 |
+
image_features = F.normalize(self.clip_model.encode_image(image))
|
67 |
+
img_set.append(image_features)
|
68 |
+
txt_set.append(txt_feature)
|
69 |
+
|
70 |
+
txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
|
71 |
+
img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
|
72 |
+
rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
|
73 |
+
rewards = torch.squeeze(rewards)
|
74 |
+
_, rank = torch.sort(rewards, dim=0, descending=True)
|
75 |
+
_, indices = torch.sort(rank, dim=0)
|
76 |
+
indices = indices + 1
|
77 |
+
|
78 |
+
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
ImageReward/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
1 |
+
from .AestheticScore import *
|
2 |
+
from .BLIPScore import *
|
3 |
+
from .CLIPScore import *
|
4 |
+
from .BLIP import *
|
ImageReward/utils.py
ADDED
@@ -0,0 +1,184 @@
1 |
+
'''
|
2 |
+
@File : utils.py
|
3 |
+
@Time : 2023/04/05 19:18:00
|
4 |
+
@Author : Jiazheng Xu
|
5 |
+
@Contact : [email protected]
|
6 |
+
* Based on CLIP code base
|
7 |
+
* https://github.com/openai/CLIP
|
8 |
+
* Checkpoint of CLIP/BLIP/Aesthetic are from:
|
9 |
+
* https://github.com/openai/CLIP
|
10 |
+
* https://github.com/salesforce/BLIP
|
11 |
+
* https://github.com/christophschuhmann/improved-aesthetic-predictor
|
12 |
+
'''
|
13 |
+
|
14 |
+
import os
|
15 |
+
import urllib
|
16 |
+
from typing import Union, List
|
17 |
+
import pathlib
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from tqdm import tqdm
|
21 |
+
from huggingface_hub import hf_hub_download
|
22 |
+
|
23 |
+
from .ImageReward import ImageReward
|
24 |
+
from .models.CLIPScore import CLIPScore
|
25 |
+
from .models.BLIPScore import BLIPScore
|
26 |
+
from .models.AestheticScore import AestheticScore
|
27 |
+
|
28 |
+
_MODELS = {
|
29 |
+
"ImageReward-v1.0": "https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt",
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
def available_models() -> List[str]:
|
34 |
+
"""Returns the names of available ImageReward models"""
|
35 |
+
return list(_MODELS.keys())
|
36 |
+
|
37 |
+
|
38 |
+
def ImageReward_download(url: str, root: str):
|
39 |
+
os.makedirs(root, exist_ok=True)
|
40 |
+
filename = os.path.basename(url)
|
41 |
+
download_target = os.path.join(root, filename)
|
42 |
+
hf_hub_download(repo_id="THUDM/ImageReward", filename=filename, local_dir=root)
|
43 |
+
return download_target
|
44 |
+
|
45 |
+
|
46 |
+
def load(name: str = "ImageReward-v1.0",
|
47 |
+
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
48 |
+
download_root: str = None,
|
49 |
+
med_config_path: str = None):
|
50 |
+
"""Load a ImageReward model
|
51 |
+
|
52 |
+
Parameters
|
53 |
+
----------
|
54 |
+
name: str
|
55 |
+
A model name listed by `ImageReward.available_models()`, or the path to a model checkpoint containing the state_dict
|
56 |
+
device: Union[str, torch.device]
|
57 |
+
The device to put the loaded model
|
58 |
+
download_root: str
|
59 |
+
path to download the model files; by default, it uses "~/.cache/ImageReward"
|
60 |
+
med_config_path: str
|
61 |
+
|
62 |
+
Returns
|
63 |
+
-------
|
64 |
+
model : torch.nn.Module
|
65 |
+
The ImageReward model
|
66 |
+
"""
|
67 |
+
if name in _MODELS:
|
68 |
+
download_root = download_root or "~/.cache/ImageReward"
|
69 |
+
download_root = pathlib.Path(download_root)
|
70 |
+
model_path = pathlib.Path(download_root) / 'ImageReward.pt'
|
71 |
+
|
72 |
+
if not model_path.exists():
|
73 |
+
model_path = ImageReward_download(_MODELS[name], root=download_root.as_posix())
|
74 |
+
elif os.path.isfile(name):
|
75 |
+
model_path = name
|
76 |
+
else:
|
77 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
78 |
+
|
79 |
+
print('-> load ImageReward model from %s' % model_path)
|
80 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
81 |
+
|
82 |
+
# med_config
|
83 |
+
if med_config_path is None:
|
84 |
+
med_config_root = download_root or "~/.cache/ImageReward"
|
85 |
+
med_config_root = pathlib.Path(med_config_root)
|
86 |
+
med_config_path = med_config_root / 'med_config.json'
|
87 |
+
|
88 |
+
if not med_config_path.exists():
|
89 |
+
med_config_path = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
|
90 |
+
root=med_config_root.as_posix())
|
91 |
+
print('-> load ImageReward med_config from %s' % med_config_path)
|
92 |
+
|
93 |
+
model = ImageReward(device=device, med_config=med_config_path).to(device)
|
94 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
95 |
+
model.eval()
|
96 |
+
|
97 |
+
return model
|
98 |
+
|
99 |
+
|
100 |
+
_SCORES = {
|
101 |
+
"CLIP": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
102 |
+
"BLIP": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth",
|
103 |
+
"Aesthetic": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth",
|
104 |
+
}
|
105 |
+
|
106 |
+
|
107 |
+
def available_scores() -> List[str]:
|
108 |
+
"""Returns the names of available ImageReward scores"""
|
109 |
+
return list(_SCORES.keys())
|
110 |
+
|
111 |
+
|
112 |
+
def _download(url: str, root: str):
|
113 |
+
os.makedirs(root, exist_ok=True)
|
114 |
+
filename = os.path.basename(url)
|
115 |
+
|
116 |
+
download_target = os.path.join(root, filename)
|
117 |
+
|
118 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
119 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
120 |
+
|
121 |
+
if os.path.isfile(download_target):
|
122 |
+
return download_target
|
123 |
+
|
124 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
125 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
|
126 |
+
unit_divisor=1024) as loop:
|
127 |
+
while True:
|
128 |
+
buffer = source.read(8192)
|
129 |
+
if not buffer:
|
130 |
+
break
|
131 |
+
|
132 |
+
output.write(buffer)
|
133 |
+
loop.update(len(buffer))
|
134 |
+
|
135 |
+
return download_target
|
136 |
+
|
137 |
+
|
138 |
+
def load_score(name: str = "CLIP", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
|
139 |
+
download_root: str = None):
|
140 |
+
"""Load a ImageReward model
|
141 |
+
|
142 |
+
Parameters
|
143 |
+
----------
|
144 |
+
name : str
|
145 |
+
A score name listed by `ImageReward.available_scores()`
|
146 |
+
|
147 |
+
device : Union[str, torch.device]
|
148 |
+
The device to put the loaded model
|
149 |
+
|
150 |
+
download_root: str
|
151 |
+
path to download the model files; by default, it uses "~/.cache/ImageReward"
|
152 |
+
|
153 |
+
Returns
|
154 |
+
-------
|
155 |
+
model : torch.nn.Module
|
156 |
+
The requested score model (CLIPScore, BLIPScore, or AestheticScore)
|
157 |
+
"""
|
158 |
+
model_download_root = download_root or os.path.expanduser("~/.cache/ImageReward")
|
159 |
+
|
160 |
+
if name in _SCORES:
|
161 |
+
model_path = _download(_SCORES[name], model_download_root)
|
162 |
+
else:
|
163 |
+
raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
|
164 |
+
|
165 |
+
print('load checkpoint from %s' % model_path)
|
166 |
+
if name == "BLIP":
|
167 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
168 |
+
med_config = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
|
169 |
+
model_download_root)
|
170 |
+
model = BLIPScore(med_config=med_config, device=device).to(device)
|
171 |
+
model.blip.load_state_dict(state_dict['model'], strict=False)
|
172 |
+
elif name == "CLIP":
|
173 |
+
model = CLIPScore(download_root=model_download_root, device=device).to(device)
|
174 |
+
elif name == "Aesthetic":
|
175 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
176 |
+
model = AestheticScore(download_root=model_download_root, device=device).to(device)
|
177 |
+
model.mlp.load_state_dict(state_dict, strict=False)
|
178 |
+
else:
|
179 |
+
raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
|
180 |
+
|
181 |
+
print("checkpoint loaded")
|
182 |
+
model.eval()
|
183 |
+
|
184 |
+
return model
|
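A short usage sketch for the two loaders above (not part of this commit). It assumes the package `__init__.py` re-exports `load`/`load_score` and uses a placeholder image path; weights are cached under `~/.cache/ImageReward` as documented in the docstrings.

```python
import ImageReward as RM

reward_model = RM.load("ImageReward-v1.0")   # ImageReward checkpoint + med_config
clip_scorer = RM.load_score("CLIP")          # or "BLIP" / "Aesthetic"

prompt = "Sydney opera house. oil painting. by Van Gogh"
print("ImageReward:", reward_model.score(prompt, "render.png"))
print("CLIPScore:  ", clip_scorer.score(prompt, "render.png"))
```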
README.md
CHANGED
@@ -1,19 +1,91 @@
|
|
1 |
# SVGDreamer: Text Guided SVG Generation with Diffusion Model
|
2 |
|
|
|
3 |
[](https://arxiv.org/abs/2312.16476)
|
4 |
-
[
|
13 |
|
14 |
-
|
15 |
|
16 |
-
- [
17 |
|
18 |
## :books: Acknowledgement
|
19 |
|
@@ -22,6 +94,8 @@ The project is built based on the following repository:
|
|
22 |
- [BachiLi/diffvg](https://github.com/BachiLi/diffvg)
|
23 |
- [huggingface/diffusers](https://github.com/huggingface/diffusers)
|
24 |
- [ximinng/DiffSketcher](https://github.com/ximinng/DiffSketcher)
|
|
|
|
|
25 |
|
26 |
We gratefully thank the authors for their wonderful works.
|
27 |
|
@@ -31,10 +105,10 @@ If you use this code for your research, please cite the following work:
|
|
31 |
|
32 |
```
|
33 |
@article{xing2023svgdreamer,
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
}
|
39 |
```
|
40 |
|
|
|
1 |
# SVGDreamer: Text Guided SVG Generation with Diffusion Model
|
2 |
|
3 |
+
[](https://arxiv.org/abs/2312.16476)
|
4 |
[](https://arxiv.org/abs/2312.16476)
|
5 |
+
[](https://ximinng.github.io/SVGDreamer-project/)
|
6 |
+
[](https://huggingface.co/blog/xingxm/svgdreamer)
|
7 |
+
[](https://huggingface.co/blog/xingxm/svgdreamer)
|
8 |
|
9 |
+
This repository contains our official implementation of the CVPR 2024 paper: SVGDreamer: Text-Guided SVG Generation with
|
10 |
+
Diffusion Model. It can generate high-quality SVGs based on text prompts.
|
11 |
|
12 |
+
[//]: # (> Project Page: https://ximinng.github.io/SVGDreamer-project/)
|
13 |
|
14 |
+

|
15 |
+

|
|
|
16 |
|
17 |
+
## :new: Update
|
18 |
|
19 |
+
- [03/2024] 🔥 We have released the **code** for [SVGDreamer](https://ximinng.github.io/SVGDreamer-project/).
|
20 |
+
- [02/2024] 🎉 **SVGDreamer accepted by CVPR2024.** 🎉
|
21 |
+
- [12/2023] 🔥 We have released the **[SVGDreamer Paper](https://arxiv.org/abs/2312.16476)**. SVGDreamer is
|
22 |
+
a novel text-guided vector graphics synthesis method. This method considers both the editing of vector graphics and
|
23 |
+
the quality of the synthesis.
|
24 |
+
|
25 |
+
## 🔥Quickstart
|
26 |
+
|
27 |
+
Before running the code, download the Stable Diffusion model by appending `diffuser.download=True` to the end of the script (only needed on the first run).
|
28 |
+
|
29 |
+
### SIVE + VPSD
|
30 |
+
|
31 |
+
**Script:**
|
32 |
+
|
33 |
+
```shell
|
34 |
+
python svgdreamer.py x=iconography skip_sive=False "prompt='an image of Batman. full body action pose, complete detailed body. white background. empty background, high quality, 4K, ultra realistic'" token_ind=4 x.vpsd.t_schedule='randint' result_path='./logs/batman' multirun=True mv=True
|
35 |
+
```
|
36 |
+
|
37 |
+
- `x=iconography`(str): style configs
|
38 |
+
- `skip_sive`(bool): whether to skip the SIVE stage; set `skip_sive=False` to run SIVE before VPSD
|
39 |
+
- `token_ind`(int): the index of the token in the text prompt (counting from 1)
|
40 |
+
- `result_path`(str): the path to save the result
|
41 |
+
- `multirun`(bool): run the script multiple times with different random seeds
|
42 |
+
- `mv`(bool): save intermediate results and render them into a video (this increases the run time)
|
43 |
+
|
44 |
+
**More parameters are defined in `./conf/x/style.yaml`, and you can override them from the command line. For
|
45 |
+
example, append `x.vpsd.n_particle=4` to the end of the script.**
|
46 |
+
|
47 |
+
### VPSD
|
48 |
+
|
49 |
+
**Prompt:** Sydney opera house. oil painting. by Van Gogh <br/>
|
50 |
+
**Style:** iconography <br/>
|
51 |
+
**Preview:**
|
52 |
+
|
53 |
+
| Particle 1 | Particle 2 | Particle 3 | Particle 4 | Particle 5 | Particle 6 |
|
54 |
+
|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|
|
55 |
+
| init p1 | init p2 | init p3 | init p4 | init p5 | init p6 |
|
56 |
+
| <img src="./assets/Icon-SydneyOperaHouse/init_p0.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p1.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p2.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p3.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p4.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p5.svg"> |
|
57 |
+
| final p1 | final p2 | final p3 | final p4 | final p5 | final p6 |
|
58 |
+
| <img src="./assets/Icon-SydneyOperaHouse/p_0.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_1.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_2.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_3.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_4.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_5.svg"> |
|
59 |
+
|
60 |
+
**Script:**
|
61 |
+
|
62 |
+
```shell
|
63 |
+
python svgdreamer.py x=iconography "prompt='Sydney opera house. oil painting. by Van Gogh'" result_path='./logs/SydneyOperaHouse-OilPainting'
|
64 |
+
```
|
65 |
+
|
66 |
+
**Other Styles:**
|
67 |
+
|
68 |
+
```shell
|
69 |
+
# Style: low-poly
|
70 |
+
python svgdreamer.py x=lowpoly "prompt='A picture of a bald eagle. low-ploy. polygon'" result_path='./logs/BaldEagle'
|
71 |
+
# Style: pixel-art
|
72 |
+
python svgdreamer.py x=pixelart "prompt='Darth vader with lightsaber.'" result_path='./log/DarthVader'
|
73 |
+
# Style: painting
|
74 |
+
python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" result_path='./logs/VanGogh-Portrait'
|
75 |
+
# Style: sketch
|
76 |
+
python svgdreamer.py x=sketch "prompt='A free-hand drawing of A speeding Lamborghini. black and white drawing.'" result_path='./logs/Lamborghini'
|
77 |
+
# Style: ink and wash
|
78 |
+
python svgdreamer.py x=ink "prompt='Big Wild Goose Pagoda. ink style. Minimalist abstract art grayscale watercolor.'" result_path='./logs/BigWildGoosePagoda'
|
79 |
+
```
|
80 |
+
|
81 |
+
## 🔑 Tips
|
82 |
+
|
83 |
+
- `x.vpsd.t_schedule` greatly affects the style of the result. Please experiment with different schedules.
|
84 |
+
- `neg_prompt`: negative prompts affect the quality of the results.
|
85 |
+
|
86 |
+
## 📋 TODO
|
87 |
+
|
88 |
+
- [x] Release the code
|
89 |
|
90 |
## :books: Acknowledgement
|
91 |
|
|
|
94 |
- [BachiLi/diffvg](https://github.com/BachiLi/diffvg)
|
95 |
- [huggingface/diffusers](https://github.com/huggingface/diffusers)
|
96 |
- [ximinng/DiffSketcher](https://github.com/ximinng/DiffSketcher)
|
97 |
+
- [THUDM/ImageReward](https://github.com/THUDM/ImageReward)
|
98 |
+
- [ximinng/PyTorch-SVGRender](https://github.com/ximinng/PyTorch-SVGRender)
|
99 |
|
100 |
We gratefully thank the authors for their wonderful works.
|
101 |
|
|
|
105 |
|
106 |
```
|
107 |
@article{xing2023svgdreamer,
|
108 |
+
title={SVGDreamer: Text Guided SVG Generation with Diffusion Model},
|
109 |
+
author={Xing, Ximing and Zhou, Haitao and Wang, Chuang and Zhang, Jing and Xu, Dong and Yu, Qian},
|
110 |
+
journal={arXiv preprint arXiv:2312.16476},
|
111 |
+
year={2023}
|
112 |
}
|
113 |
```
|
114 |
|
assets/Icon-SydneyOperaHouse/init_p0.svg
ADDED
|
assets/Icon-SydneyOperaHouse/init_p1.svg
ADDED
|
assets/Icon-SydneyOperaHouse/init_p2.svg
ADDED
|
assets/Icon-SydneyOperaHouse/init_p3.svg
ADDED
|
assets/Icon-SydneyOperaHouse/init_p4.svg
ADDED
|
assets/Icon-SydneyOperaHouse/init_p5.svg
ADDED
|
assets/Icon-SydneyOperaHouse/p_0.svg
ADDED
|
assets/Icon-SydneyOperaHouse/p_1.svg
ADDED
|
assets/Icon-SydneyOperaHouse/p_2.svg
ADDED
|
assets/Icon-SydneyOperaHouse/p_3.svg
ADDED
|
assets/Icon-SydneyOperaHouse/p_4.svg
ADDED
|
assets/Icon-SydneyOperaHouse/p_5.svg
ADDED
|
assets/{teaser1.png → illustrate.png}
RENAMED
File without changes
|
assets/{teaser2.png → teaser_cases.png}
RENAMED
File without changes
|
assets/{teaser3.png → teaser_more_cases.png}
RENAMED
File without changes
|
assets/teaser_svg_asset.png
ADDED
|
conf/config.yaml
ADDED
@@ -0,0 +1,54 @@
1 |
+
#-----------------#
|
2 |
+
# Global Config #
|
3 |
+
#-----------------#
|
4 |
+
|
5 |
+
# common args
|
6 |
+
prompt: ~
|
7 |
+
token_ind: 1 # the index of the token in the text prompt (counting from 1)
|
8 |
+
neg_prompt: ~ # negative prompt
|
9 |
+
skip_sive: True # optimize from scratch without SIVE init
|
10 |
+
|
11 |
+
# Accelerate config
|
12 |
+
state:
|
13 |
+
cpu: False # use cpu
|
14 |
+
mprec: no # mixed precision, choices: 'no', 'fp16', 'bf16'
|
15 |
+
|
16 |
+
# Diffusers config
|
17 |
+
diffuser:
|
18 |
+
download: False # Set this variable to True the first time it runs
|
19 |
+
force_download: False
|
20 |
+
resume_download: False
|
21 |
+
|
22 |
+
# PyDiffVG config
|
23 |
+
diffvg:
|
24 |
+
print_timing: False
|
25 |
+
|
26 |
+
# reproduction
|
27 |
+
seed: 951222
|
28 |
+
# multi-run
|
29 |
+
multirun: False
|
30 |
+
srange: ~ # seed range, example: [100, 100]
|
31 |
+
|
32 |
+
# log
|
33 |
+
result_path: './workspace'
|
34 |
+
save_step: 50
|
35 |
+
|
36 |
+
# visual rendering process
|
37 |
+
mv: False # make video
|
38 |
+
framefreq: 5 # save the image interval
|
39 |
+
framerate: 24 # by adjusting the frame rate, you can control the playback speed of the output video
|
40 |
+
|
41 |
+
# hydra setting
|
42 |
+
hydra:
|
43 |
+
help:
|
44 |
+
# app name, override to match the name your app is known by
|
45 |
+
app_name: 'SVGDreamer'
|
46 |
+
run:
|
47 |
+
# output directory for normal runs
|
48 |
+
# warning: make sure that L53-55 of './libs/model_state.py' and 'dir' are modified together
|
49 |
+
dir: ./${result_path}/SVGDreamer-${now:%Y-%m-%d-%H-%M}
|
50 |
+
|
51 |
+
# default settings
|
52 |
+
defaults:
|
53 |
+
- _self_
|
54 |
+
- x: ~
|
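For context, a minimal Hydra entry-point sketch showing how this global config composes with a style file under `conf/x/` (the actual logic lives in `svgdreamer.py`; this is only an assumed, simplified reading):

```python
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    # cfg carries the globals above; cfg.x holds the chosen style config, e.g. `x=iconography`
    print(OmegaConf.to_yaml(cfg))
    print(cfg.prompt, cfg.seed, cfg.result_path)

if __name__ == "__main__":
    main()
```

Any key shown above can then be overridden from the command line, e.g. `python svgdreamer.py x=iconography seed=123 mv=True`.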
conf/x/iconography.yaml
ADDED
@@ -0,0 +1,188 @@
1 |
+
image_size: 600 # canvas size
|
2 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
3 |
+
color_init: 'rand' # if skip_sive=True, then use color_init to init target_img
|
4 |
+
style: "iconography" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
|
5 |
+
|
6 |
+
# stable diffusion in SIVE stage
|
7 |
+
sive_model_cfg:
|
8 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
9 |
+
ldm_speed_up: False
|
10 |
+
enable_xformers: True
|
11 |
+
gradient_checkpoint: False
|
12 |
+
cpu_offload: True
|
13 |
+
num_inference_steps: 100
|
14 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
15 |
+
lora_path: ~
|
16 |
+
|
17 |
+
# lr and optim
|
18 |
+
sive_stage_optim:
|
19 |
+
point: 1 # control points
|
20 |
+
width: 0.1 # stroke width
|
21 |
+
color: 0.01 # fill color and stroke color
|
22 |
+
bg: 0.01 # bg in render_warp
|
23 |
+
optim:
|
24 |
+
name: 'adam'
|
25 |
+
betas: [ 0.9, 0.9 ]
|
26 |
+
eps: 1e-6
|
27 |
+
schedule:
|
28 |
+
name: 'linear'
|
29 |
+
keep_ratio: 0.2
|
30 |
+
decay_ratio: 0.4
|
31 |
+
|
32 |
+
# SIVE rendering
|
33 |
+
sive:
|
34 |
+
attn_cfg: # init content via attn
|
35 |
+
cross_attn_res: 16
|
36 |
+
self_attn_res: 32
|
37 |
+
max_com: 20
|
38 |
+
mean_comp: False
|
39 |
+
comp_idx: 0
|
40 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
bg:
|
42 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
+
num_iter: 10
|
44 |
+
num_paths: 256
|
45 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
46 |
+
schedule_each: 128
|
47 |
+
width: 3 # sketch stroke width
|
48 |
+
num_segments: 4
|
49 |
+
segment_init: 'circle' # 'random'
|
50 |
+
radius: 20
|
51 |
+
coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
|
52 |
+
grid: 20
|
53 |
+
# optim
|
54 |
+
lr_schedule: True
|
55 |
+
optim_bg: False # train background
|
56 |
+
use_attn_init: True
|
57 |
+
softmax_tau: 0.3 # temperature of softmax
|
58 |
+
# loss
|
59 |
+
use_distance_weighted_loss: False
|
60 |
+
xing_loss_weight: 0.001
|
61 |
+
fg:
|
62 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
63 |
+
num_iter: 10
|
64 |
+
num_paths: 256 # number of strokes
|
65 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
66 |
+
schedule_each: 128
|
67 |
+
width: 3 # sketch stroke width
|
68 |
+
num_segments: 4
|
69 |
+
segment_init: 'circle' # 'random'
|
70 |
+
radius: 15
|
71 |
+
coord_init: 'random' # 'random', 'naive', place the first control point
|
72 |
+
grid: 20
|
73 |
+
# optim
|
74 |
+
lr_schedule: False
|
75 |
+
optim_bg: False # train background
|
76 |
+
use_attn_init: True
|
77 |
+
softmax_tau: 0.3 # temperature of softmax
|
78 |
+
# loss
|
79 |
+
use_distance_weighted_loss: False
|
80 |
+
xing_loss_weight: 0.01
|
81 |
+
tog: # for refinement
|
82 |
+
reinit: True # if False, use fg params to init content
|
83 |
+
num_iter: 1000
|
84 |
+
# optim
|
85 |
+
lr_schedule: False # enable lr_scheduler or not
|
86 |
+
# loss
|
87 |
+
bg_lam: 0
|
88 |
+
fg_lam: 1
|
89 |
+
xing_loss_weight: 0
|
90 |
+
|
91 |
+
# VPSD primitives
|
92 |
+
num_paths: 512 # number of strokes
|
93 |
+
trainable_bg: False # set the background to be trainable
|
94 |
+
width: 3 # stroke width
|
95 |
+
num_segments: 4
|
96 |
+
segment_init: 'circle' # 'random'
|
97 |
+
radius: 20
|
98 |
+
coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
|
99 |
+
grid: 50 # divide the canvas into n grids
|
100 |
+
path_reinit: # reinitializing paths
|
101 |
+
use: True
|
102 |
+
freq: 100 # reinitialize paths every 100 iterations
|
103 |
+
stop_step: 1000 # for VPSD fine-tuning
|
104 |
+
opacity_threshold: 0.05
|
105 |
+
area_threshold: 64
|
106 |
+
|
107 |
+
# lr and optim
|
108 |
+
vpsd_stage_optim:
|
109 |
+
point: 1
|
110 |
+
width: 0.1
|
111 |
+
color: 0.01
|
112 |
+
bg: 0.01
|
113 |
+
lr_schedule: True # use lr_scheduler
|
114 |
+
optim:
|
115 |
+
name: 'adam'
|
116 |
+
betas: [ 0.9, 0.9 ]
|
117 |
+
eps: 1e-6
|
118 |
+
schedule:
|
119 |
+
name: 'cosine'
|
120 |
+
warmup_steps: 10
|
121 |
+
warmup_start_lr: 0.02
|
122 |
+
warmup_end_lr: 0.9
|
123 |
+
cosine_end_lr: 0.4
|
124 |
+
|
125 |
+
# stable diffusion in VPSD stage
|
126 |
+
vpsd_model_cfg:
|
127 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
128 |
+
ldm_speed_up: False
|
129 |
+
enable_xformers: True
|
130 |
+
gradient_checkpoint: False
|
131 |
+
cpu_offload: True
|
132 |
+
num_inference_steps: 100
|
133 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
134 |
+
lora_path: ~
|
135 |
+
|
136 |
+
# VPSD setting
|
137 |
+
vpsd:
|
138 |
+
type: 'vpsd'
|
139 |
+
n_particle: 6 # 4, 8, 16
|
140 |
+
vsd_n_particle: 4 # the batch size of particles
|
141 |
+
particle_aug: False # do data enhancement for the input particles
|
142 |
+
num_iter: 2000 # total iterations
|
143 |
+
guidance_scale: 7.5 # CFG value
|
144 |
+
grad_scale: 1.0 # increase or decrease the gradient
|
145 |
+
grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
|
146 |
+
t_range: [ 0.02, 0.98 ]
|
147 |
+
# 'randint': random time steps, this may have a more authentic style.
|
148 |
+
# 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may produce more colorful results.
|
149 |
+
t_schedule: 'max_0.5_1500' # or 'randint'
|
150 |
+
# phi model config
|
151 |
+
phi_single: False # if False, create a new UNet model to estimate noise
|
152 |
+
phi_model: 'lora' # 'lora', 'unet_simple'
|
153 |
+
use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
|
154 |
+
lora_attn_scale: 1.0 # the scale of the attn based lora layer
|
155 |
+
phi_guidance_scale: 1.0
|
156 |
+
phi_t: False # different t for phi fine-tuning
|
157 |
+
phi_update_step: 1 # enable multi-update phi model or not
|
158 |
+
phi_lr: 0.0001 # learning rate of phi model
|
159 |
+
phi_scheduler: 'ddim'
|
160 |
+
phi_n_particle: 2 # the batch size of phi_model
|
161 |
+
# ReFL config
|
162 |
+
phi_ReFL: False # enable reward feed back learning
|
163 |
+
n_phi_sample: 1 # number of samples used in ReFL
|
164 |
+
phi_sample_step: 200 # the phi log step
|
165 |
+
phi_infer_step: 50 # the phi num_inference_steps
|
166 |
+
# phi model optim
|
167 |
+
phi_optim:
|
168 |
+
name: 'adamw'
|
169 |
+
betas: [ 0.9, 0.999 ]
|
170 |
+
eps: 1e-8
|
171 |
+
weight_decay: ~ # 1e-5
|
172 |
+
# phi model lr learning schedule
|
173 |
+
phi_schedule:
|
174 |
+
use: False
|
175 |
+
name: 'cosine'
|
176 |
+
warmup_steps: 50
|
177 |
+
warmup_start_lr: 0.00001
|
178 |
+
warmup_end_lr: 0.0001
|
179 |
+
total_step: 800
|
180 |
+
cosine_end_lr: 0.0001
|
181 |
+
|
182 |
+
# reward model
|
183 |
+
reward_path: './checkpoint/ImageReward'
|
184 |
+
|
185 |
+
# xing loss for closed-form paths
|
186 |
+
xing_loss:
|
187 |
+
use: False
|
188 |
+
weight: 0.01
|
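The `t_schedule` values above (`'randint'`, `'max_0.5_1500'`) control how diffusion timesteps are sampled during VPSD. The snippet below is only an assumed reading of that naming convention, sketched for clarity; the authoritative implementation is in `svgdreamer/painter/VPSD_pipeline.py`:

```python
import random

def sample_timestep(schedule: str, step: int,
                    t_range=(0.02, 0.98), num_train_timesteps=1000) -> int:
    """Assumed semantics: 'randint' samples uniformly within t_range;
    'max_<m>_<k>' lowers the upper bound to m once `step` reaches k."""
    t_min, t_max = t_range
    if schedule.startswith('max_'):
        _, new_max, switch_step = schedule.split('_')
        if step >= int(switch_step):
            t_max = float(new_max)
    return random.randint(int(t_min * num_train_timesteps),
                          int(t_max * num_train_timesteps) - 1)

# e.g. with 'max_0.5_1500', late iterations draw t from [20, 499] instead of [20, 979]
```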
conf/x/ink.yaml
ADDED
@@ -0,0 +1,188 @@
1 |
+
image_size: 600 # canvas size
|
2 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
3 |
+
color_init: 'rand' # if skip_sive=True, then use color_init to init target_img
|
4 |
+
style: "ink" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
|
5 |
+
|
6 |
+
# stable diffusion in SIVE stage
|
7 |
+
sive_model_cfg:
|
8 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
9 |
+
ldm_speed_up: False
|
10 |
+
enable_xformers: True
|
11 |
+
gradient_checkpoint: False
|
12 |
+
cpu_offload: True
|
13 |
+
num_inference_steps: 100
|
14 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
15 |
+
lora_path: ~
|
16 |
+
|
17 |
+
# lr and optim
|
18 |
+
sive_stage_optim:
|
19 |
+
point: 1 # control points
|
20 |
+
width: 0.1 # stroke width
|
21 |
+
color: 0.01 # fill color and stroke color
|
22 |
+
bg: 0.01 # bg in render_warp
|
23 |
+
optim:
|
24 |
+
name: 'adam'
|
25 |
+
betas: [ 0.9, 0.9 ]
|
26 |
+
eps: 1e-6
|
27 |
+
schedule:
|
28 |
+
name: 'linear'
|
29 |
+
keep_ratio: 0.2
|
30 |
+
decay_ratio: 0.4
|
31 |
+
|
32 |
+
# SIVE rendering
|
33 |
+
sive:
|
34 |
+
attn_cfg: # init content via attn
|
35 |
+
cross_attn_res: 16
|
36 |
+
self_attn_res: 32
|
37 |
+
max_com: 20
|
38 |
+
mean_comp: False
|
39 |
+
comp_idx: 0
|
40 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
bg:
|
42 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
+
num_iter: 10
|
44 |
+
num_paths: 256
|
45 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
46 |
+
schedule_each: 128
|
47 |
+
width: 3 # sketch stroke width
|
48 |
+
num_segments: 4
|
49 |
+
segment_init: 'circle' # 'random'
|
50 |
+
radius: 20
|
51 |
+
coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
|
52 |
+
grid: 20
|
53 |
+
# optim
|
54 |
+
lr_schedule: True
|
55 |
+
optim_bg: False # train background
|
56 |
+
use_attn_init: True
|
57 |
+
softmax_tau: 0.3 # temperature of softmax
|
58 |
+
# loss
|
59 |
+
use_distance_weighted_loss: False
|
60 |
+
xing_loss_weight: 0.001
|
61 |
+
fg:
|
62 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
63 |
+
num_iter: 10
|
64 |
+
num_paths: 256 # number of strokes
|
65 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
66 |
+
schedule_each: 128
|
67 |
+
width: 3 # sketch stroke width
|
68 |
+
num_segments: 4
|
69 |
+
segment_init: 'circle' # 'random'
|
70 |
+
radius: 15
|
71 |
+
coord_init: 'random' # 'random', 'naive', place the first control point
|
72 |
+
grid: 20
|
73 |
+
# optim
|
74 |
+
lr_schedule: False
|
75 |
+
optim_bg: False # train background
|
76 |
+
use_attn_init: True
|
77 |
+
softmax_tau: 0.3 # temperature of softmax
|
78 |
+
# loss
|
79 |
+
use_distance_weighted_loss: False
|
80 |
+
xing_loss_weight: 0.01
|
81 |
+
tog: # for refinement
|
82 |
+
reinit: True # if False, use fg params to init content
|
83 |
+
num_iter: 1000
|
84 |
+
# optim
|
85 |
+
lr_schedule: False # enable lr_scheduler or not
|
86 |
+
# loss
|
87 |
+
bg_lam: 0
|
88 |
+
fg_lam: 1
|
89 |
+
xing_loss_weight: 0
|
90 |
+
|
91 |
+
# VPSD primitives
|
92 |
+
num_paths: 128 # number of strokes
|
93 |
+
trainable_bg: False # set the background to be trainable
|
94 |
+
width: 6 # stroke width
|
95 |
+
num_segments: 4
|
96 |
+
segment_init: 'circle' # 'random'
|
97 |
+
radius: 20
|
98 |
+
coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
|
99 |
+
grid: 50 # divide the canvas into n grids
|
100 |
+
path_reinit: # reinitializing paths
|
101 |
+
use: True
|
102 |
+
freq: 100 # reinitialize paths every 100 iterations
|
103 |
+
stop_step: 1000 # for VPSD fine-tuning
|
104 |
+
opacity_threshold: 0.05
|
105 |
+
area_threshold: 64
|
106 |
+
|
107 |
+
# lr and optim
|
108 |
+
vpsd_stage_optim:
|
109 |
+
point: 1
|
110 |
+
width: 0.1
|
111 |
+
color: 0.01
|
112 |
+
bg: 0.01
|
113 |
+
lr_schedule: True # use lr_scheduler
|
114 |
+
optim:
|
115 |
+
name: 'adam'
|
116 |
+
betas: [ 0.9, 0.9 ]
|
117 |
+
eps: 1e-6
|
118 |
+
schedule:
|
119 |
+
name: 'cosine'
|
120 |
+
warmup_steps: 10
|
121 |
+
warmup_start_lr: 0.02
|
122 |
+
warmup_end_lr: 0.9
|
123 |
+
cosine_end_lr: 0.4
|
124 |
+
|
125 |
+
# stable diffusion in VPSD stage
|
126 |
+
vpsd_model_cfg:
|
127 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
128 |
+
ldm_speed_up: False
|
129 |
+
enable_xformers: True
|
130 |
+
gradient_checkpoint: False
|
131 |
+
cpu_offload: True
|
132 |
+
num_inference_steps: 100
|
133 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
134 |
+
lora_path: ~
|
135 |
+
|
136 |
+
# VPSD setting
|
137 |
+
vpsd:
|
138 |
+
type: 'vpsd'
|
139 |
+
n_particle: 6 # 4, 8, 16
|
140 |
+
vsd_n_particle: 4 # the batch size of particles
|
141 |
+
particle_aug: False # do data enhancement for the input particles
|
142 |
+
num_iter: 2000 # total iterations
|
143 |
+
guidance_scale: 7.5 # CFG value
|
144 |
+
grad_scale: 1.0 # increase or decrease the gradient
|
145 |
+
grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
|
146 |
+
t_range: [ 0.02, 0.98 ]
|
147 |
+
# 'randint': random time steps, this may have a more authentic style.
|
148 |
+
# 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may produce more colorful results.
|
149 |
+
t_schedule: 'randint' # or 'max_0.5_900'
|
150 |
+
# phi model config
|
151 |
+
phi_single: False # if False, create a new UNet model to estimate noise
|
152 |
+
phi_model: 'lora' # 'lora', 'unet_simple'
|
153 |
+
use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
|
154 |
+
lora_attn_scale: 1.0 # the scale of the attn based lora layer
|
155 |
+
phi_guidance_scale: 1.0
|
156 |
+
phi_t: False # different t for phi fine-tuning
|
157 |
+
phi_update_step: 1 # enable multi-update phi model or not
|
158 |
+
phi_lr: 0.0001 # learning rate of phi model
|
159 |
+
phi_scheduler: 'ddim'
|
160 |
+
phi_n_particle: 2 # the batch size of phi_model
|
161 |
+
# ReFL config
|
162 |
+
phi_ReFL: False # enable reward feed back learning
|
163 |
+
n_phi_sample: 1 # number of samples used in ReFL
|
164 |
+
phi_sample_step: 200 # the phi log step
|
165 |
+
phi_infer_step: 50 # the phi num_inference_steps
|
166 |
+
# phi model optim
|
167 |
+
phi_optim:
|
168 |
+
name: 'adamw'
|
169 |
+
betas: [ 0.9, 0.999 ]
|
170 |
+
eps: 1e-8
|
171 |
+
weight_decay: ~ # 1e-5
|
172 |
+
# phi model lr learning schedule
|
173 |
+
phi_schedule:
|
174 |
+
use: False
|
175 |
+
name: 'cosine'
|
176 |
+
warmup_steps: 50
|
177 |
+
warmup_start_lr: 0.00001
|
178 |
+
warmup_end_lr: 0.0001
|
179 |
+
total_step: 800
|
180 |
+
cosine_end_lr: 0.0001
|
181 |
+
|
182 |
+
# reward model
|
183 |
+
reward_path: './checkpoint/ImageReward'
|
184 |
+
|
185 |
+
# xing loss for closed-form paths
|
186 |
+
xing_loss:
|
187 |
+
use: False
|
188 |
+
weight: 0.01
|
conf/x/lowpoly.yaml
ADDED
@@ -0,0 +1,188 @@
1 |
+
image_size: 600 # canvas size
|
2 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
3 |
+
color_init: 'rand' # if skip_sive=True, then use color_init to init target_img
|
4 |
+
style: "low-poly" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
|
5 |
+
|
6 |
+
# stable diffusion in SIVE stage
|
7 |
+
sive_model_cfg:
|
8 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
9 |
+
ldm_speed_up: False
|
10 |
+
enable_xformers: True
|
11 |
+
gradient_checkpoint: False
|
12 |
+
cpu_offload: True
|
13 |
+
num_inference_steps: 100
|
14 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
15 |
+
lora_path: ~
|
16 |
+
|
17 |
+
# lr and optim
|
18 |
+
sive_stage_optim:
|
19 |
+
point: 1 # control points
|
20 |
+
width: 0.1 # stroke width
|
21 |
+
color: 0.01 # fill color and stroke color
|
22 |
+
bg: 0.01 # bg in render_warp
|
23 |
+
optim:
|
24 |
+
name: 'adam'
|
25 |
+
betas: [ 0.9, 0.9 ]
|
26 |
+
eps: 1e-6
|
27 |
+
schedule:
|
28 |
+
name: 'linear'
|
29 |
+
keep_ratio: 0.2
|
30 |
+
decay_ratio: 0.4
|
31 |
+
|
32 |
+
# SIVE rendering
|
33 |
+
sive:
|
34 |
+
attn_cfg: # init content via attn
|
35 |
+
cross_attn_res: 16
|
36 |
+
self_attn_res: 32
|
37 |
+
max_com: 20
|
38 |
+
mean_comp: False
|
39 |
+
comp_idx: 0
|
40 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
bg:
|
42 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
+
num_iter: 10
|
44 |
+
num_paths: 256
|
45 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
46 |
+
schedule_each: 128
|
47 |
+
width: 3 # sketch stroke width
|
48 |
+
num_segments: 4
|
49 |
+
segment_init: 'circle' # 'random'
|
50 |
+
radius: 20
|
51 |
+
coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
|
52 |
+
grid: 20
|
53 |
+
# optim
|
54 |
+
lr_schedule: True
|
55 |
+
optim_bg: False # train background
|
56 |
+
use_attn_init: True
|
57 |
+
softmax_tau: 0.3 # temperature of softmax
|
58 |
+
# loss
|
59 |
+
use_distance_weighted_loss: False
|
60 |
+
xing_loss_weight: 0.001
|
61 |
+
fg:
|
62 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
63 |
+
num_iter: 10
|
64 |
+
num_paths: 256 # number of strokes
|
65 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
66 |
+
schedule_each: 128
|
67 |
+
width: 3 # sketch stroke width
|
68 |
+
num_segments: 4
|
69 |
+
segment_init: 'circle' # 'random'
|
70 |
+
radius: 15
|
71 |
+
coord_init: 'random' # 'random', 'naive', place the first control point
|
72 |
+
grid: 20
|
73 |
+
# optim
|
74 |
+
lr_schedule: False
|
75 |
+
optim_bg: False # train background
|
76 |
+
use_attn_init: True
|
77 |
+
softmax_tau: 0.3 # temperature of softmax
|
78 |
+
# loss
|
79 |
+
use_distance_weighted_loss: False
|
80 |
+
xing_loss_weight: 0.01
|
81 |
+
tog: # for refinement
|
82 |
+
reinit: True # if False, use fg params to init content
|
83 |
+
num_iter: 1000
|
84 |
+
# optim
|
85 |
+
lr_schedule: False # enable lr_scheduler or not
|
86 |
+
# loss
|
87 |
+
bg_lam: 0
|
88 |
+
fg_lam: 1
|
89 |
+
xing_loss_weight: 0
|
90 |
+
|
91 |
+
# VPSD primitives
|
92 |
+
num_paths: 512 # number of strokes
|
93 |
+
trainable_bg: False # set the background to be trainable
|
94 |
+
width: 3 # stroke width
|
95 |
+
num_segments: 4
|
96 |
+
segment_init: 'circle' # 'random'
|
97 |
+
radius: 20
|
98 |
+
coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
|
99 |
+
grid: 30 # divide the canvas into n grids
|
100 |
+
path_reinit: # reinitializing paths
|
101 |
+
use: True
|
102 |
+
freq: 100 # reinitialize paths every 100 iterations
|
103 |
+
stop_step: 1000 # for VPSD fine-tuning
|
104 |
+
opacity_threshold: 0.05
|
105 |
+
area_threshold: 64
|
106 |
+
|
107 |
+
# lr and optim
|
108 |
+
vpsd_stage_optim:
|
109 |
+
point: 1
|
110 |
+
width: 0.1
|
111 |
+
color: 0.01
|
112 |
+
bg: 0.01
|
113 |
+
lr_schedule: True # use lr_scheduler
|
114 |
+
optim:
|
115 |
+
name: 'adam'
|
116 |
+
betas: [ 0.9, 0.9 ]
|
117 |
+
eps: 1e-6
|
118 |
+
schedule:
|
119 |
+
name: 'cosine'
|
120 |
+
warmup_steps: 10
|
121 |
+
warmup_start_lr: 0.02
|
122 |
+
warmup_end_lr: 0.9
|
123 |
+
cosine_end_lr: 0.4
|
124 |
+
|
125 |
+
# stable diffusion in VPSD stage
|
126 |
+
vpsd_model_cfg:
|
127 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
128 |
+
ldm_speed_up: False
|
129 |
+
enable_xformers: True
|
130 |
+
gradient_checkpoint: False
|
131 |
+
cpu_offload: True
|
132 |
+
num_inference_steps: 100
|
133 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
134 |
+
lora_path: ~
|
135 |
+
|
136 |
+
# VPSD setting
|
137 |
+
vpsd:
|
138 |
+
type: 'vpsd'
|
139 |
+
n_particle: 6 # 4, 8, 16
|
140 |
+
vsd_n_particle: 4 # the batch size of particles
|
141 |
+
particle_aug: False # do data enhancement for the input particles
|
142 |
+
num_iter: 1500 # total iterations
|
143 |
+
guidance_scale: 7.5 # CFG value
|
144 |
+
grad_scale: 1.0 # increase or decrease the gradient
|
145 |
+
grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
|
146 |
+
t_range: [ 0.02, 0.98 ]
|
147 |
+
# 'randint': random time steps, this may have a more authentic style.
|
148 |
+
# 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may produce more colorful results.
|
149 |
+
t_schedule: 'max_0.5_1500' # or 'randint'
|
150 |
+
# phi model config
|
151 |
+
phi_single: False # if False, create a new UNet model to estimate noise
|
152 |
+
phi_model: 'lora' # 'lora', 'unet_simple'
|
153 |
+
use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
|
154 |
+
lora_attn_scale: 1.0 # the scale of the attn based lora layer
|
155 |
+
phi_guidance_scale: 1.0
|
156 |
+
phi_t: False # different t for phi fine-tuning
|
157 |
+
phi_update_step: 1 # enable multi-update phi model or not
|
158 |
+
phi_lr: 0.0001 # learning rate of phi model
|
159 |
+
phi_scheduler: 'ddim'
|
160 |
+
phi_n_particle: 2 # the batch size of phi_model
|
161 |
+
# ReFL config
|
162 |
+
phi_ReFL: False # enable reward feed back learning
|
163 |
+
n_phi_sample: 1 # number of samples used in ReFL
|
164 |
+
phi_sample_step: 200 # the phi log step
|
165 |
+
phi_infer_step: 50 # the phi num_inference_steps
|
166 |
+
# phi model optim
|
167 |
+
phi_optim:
|
168 |
+
name: 'adamw'
|
169 |
+
betas: [ 0.9, 0.999 ]
|
170 |
+
eps: 1e-8
|
171 |
+
weight_decay: ~ # 1e-5
|
172 |
+
# phi model lr learning schedule
|
173 |
+
phi_schedule:
|
174 |
+
use: False
|
175 |
+
name: 'cosine'
|
176 |
+
warmup_steps: 50
|
177 |
+
warmup_start_lr: 0.00001
|
178 |
+
warmup_end_lr: 0.0001
|
179 |
+
total_step: 800
|
180 |
+
cosine_end_lr: 0.0001
|
181 |
+
|
182 |
+
# reward model
|
183 |
+
reward_path: './checkpoint/ImageReward'
|
184 |
+
|
185 |
+
# xing loss for closed-form paths
|
186 |
+
xing_loss:
|
187 |
+
use: False
|
188 |
+
weight: 0.01
|
conf/x/painting.yaml
ADDED
@@ -0,0 +1,188 @@
1 |
+
image_size: 600 # canvas size
|
2 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
3 |
+
color_init: 'rand' # if skip_live=True, then use color_init to init target_img
|
4 |
+
style: "painting" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
|
5 |
+
|
6 |
+
# stable diffusion in SIVE stage
|
7 |
+
sive_model_cfg:
|
8 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
9 |
+
ldm_speed_up: False
|
10 |
+
enable_xformers: True
|
11 |
+
gradient_checkpoint: False
|
12 |
+
cpu_offload: True
|
13 |
+
num_inference_steps: 100
|
14 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
15 |
+
lora_path: ~
|
16 |
+
|
17 |
+
# lr and optim
|
18 |
+
sive_stage_optim:
|
19 |
+
point: 1 # control points
|
20 |
+
width: 0.1 # stroke width
|
21 |
+
color: 0.01 # fill color and stroke color
|
22 |
+
bg: 0.01 # bg in render_warp
|
23 |
+
optim:
|
24 |
+
name: 'adam'
|
25 |
+
betas: [ 0.9, 0.9 ]
|
26 |
+
eps: 1e-6
|
27 |
+
schedule:
|
28 |
+
name: 'linear'
|
29 |
+
keep_ratio: 0.2
|
30 |
+
decay_ratio: 0.4
|
31 |
+
|
32 |
+
# SIVE rendering
|
33 |
+
sive:
|
34 |
+
attn_cfg: # init content via attn
|
35 |
+
cross_attn_res: 16
|
36 |
+
self_attn_res: 32
|
37 |
+
max_com: 20
|
38 |
+
mean_comp: False
|
39 |
+
comp_idx: 0
|
40 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
bg:
|
42 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
+
num_iter: 10
|
44 |
+
num_paths: 256
|
45 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
46 |
+
schedule_each: 128
|
47 |
+
width: 3 # sketch stroke width
|
48 |
+
num_segments: 4
|
49 |
+
segment_init: 'circle' # 'random'
|
50 |
+
radius: 20
|
51 |
+
coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
|
52 |
+
grid: 20
|
53 |
+
# optim
|
54 |
+
lr_schedule: True
|
55 |
+
optim_bg: False # train background
|
56 |
+
use_attn_init: True
|
57 |
+
softmax_tau: 0.3 # temperature of softmax
|
58 |
+
# loss
|
59 |
+
use_distance_weighted_loss: False
|
60 |
+
xing_loss_weight: 0.001
|
61 |
+
fg:
|
62 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
63 |
+
num_iter: 10
|
64 |
+
num_paths: 256 # number of strokes
|
65 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
66 |
+
schedule_each: 128
|
67 |
+
width: 3 # sketch stroke width
|
68 |
+
num_segments: 4
|
69 |
+
segment_init: 'circle' # 'random'
|
70 |
+
radius: 15
|
71 |
+
coord_init: 'random' # 'random', 'naive', place the first control point
|
72 |
+
grid: 20
|
73 |
+
# optim
|
74 |
+
lr_schedule: False
|
75 |
+
optim_bg: False # train background
|
76 |
+
use_attn_init: True
|
77 |
+
softmax_tau: 0.3 # temperature of softmax
|
78 |
+
# loss
|
79 |
+
use_distance_weighted_loss: False
|
80 |
+
xing_loss_weight: 0.01
|
81 |
+
tog: # for refinement
|
82 |
+
reinit: True # if False, use fg params to init content
|
83 |
+
num_iter: 1000
|
84 |
+
# optim
|
85 |
+
lr_schedule: False # enable lr_scheduler or not
|
86 |
+
# loss
|
87 |
+
bg_lam: 0
|
88 |
+
fg_lam: 1
|
89 |
+
xing_loss_weight: 0
|
90 |
+
|
91 |
+
# VPSD primitives
|
92 |
+
num_paths: 1500 # number of strokes
|
93 |
+
trainable_bg: False # set the background to be trainable
|
94 |
+
width: 3 # stroke width
|
95 |
+
num_segments: 4
|
96 |
+
segment_init: 'circle' # 'random'
|
97 |
+
radius: 20
|
98 |
+
coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
|
99 |
+
grid: 50 # divide the canvas into n grids
|
100 |
+
path_reinit: # reinitializing paths
|
101 |
+
use: True
|
102 |
+
freq: 100 # every 100 iterations
|
103 |
+
stop_step: 1000 # for VPSD fine-tuning
|
104 |
+
opacity_threshold: 0.05
|
105 |
+
area_threshold: 64
|
106 |
+
|
107 |
+
# lr and optim
|
108 |
+
vpsd_stage_optim:
|
109 |
+
point: 1
|
110 |
+
width: 0.1
|
111 |
+
color: 0.01
|
112 |
+
bg: 0.01
|
113 |
+
lr_schedule: True # use lr_scheduler
|
114 |
+
optim:
|
115 |
+
name: 'adam'
|
116 |
+
betas: [ 0.9, 0.9 ]
|
117 |
+
eps: 1e-6
|
118 |
+
schedule:
|
119 |
+
name: 'cosine'
|
120 |
+
warmup_steps: 10
|
121 |
+
warmup_start_lr: 0.02
|
122 |
+
warmup_end_lr: 0.9
|
123 |
+
cosine_end_lr: 0.4
|
124 |
+
|
125 |
+
# stable diffusion in VPSD stage
|
126 |
+
vpsd_model_cfg:
|
127 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
128 |
+
ldm_speed_up: False
|
129 |
+
enable_xformers: True
|
130 |
+
gradient_checkpoint: False
|
131 |
+
cpu_offload: True
|
132 |
+
num_inference_steps: 100
|
133 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
134 |
+
lora_path: ~
|
135 |
+
|
136 |
+
# VPSD setting
|
137 |
+
vpsd:
|
138 |
+
type: 'vpsd'
|
139 |
+
n_particle: 6 # 4, 8, 16
|
140 |
+
vsd_n_particle: 4 # the batch size of particles
|
141 |
+
particle_aug: False # apply data augmentation to the input particles
|
142 |
+
num_iter: 2000 # total iterations
|
143 |
+
guidance_scale: 7.5 # CFG value
|
144 |
+
grad_scale: 1.0 # increase or decrease the gradient
|
145 |
+
grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
|
146 |
+
t_range: [ 0.02, 0.98 ]
|
147 |
+
# 'randint': random time steps; this may give a more authentic style.
# 'max_0.5_900': anneal from 0.98 to 0.5 after 900 steps; this may give more colorful results.
t_schedule: 'randint' # or e.g. 'max_0.5_900'
|
150 |
+
# phi model config
|
151 |
+
phi_single: False # if False, instantiate a new UNet model to estimate noise
|
152 |
+
phi_model: 'lora' # 'lora', 'unet_simple'
|
153 |
+
use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
|
154 |
+
lora_attn_scale: 1.0 # the scale of the attn based lora layer
|
155 |
+
phi_guidance_scale: 1.0
|
156 |
+
phi_t: False # different t for phi fine-tuning
|
157 |
+
phi_update_step: 1 # enable multi-update phi model or not
|
158 |
+
phi_lr: 0.0001 # learning rate of phi model
|
159 |
+
phi_scheduler: 'ddim'
|
160 |
+
phi_n_particle: 2 # the batch size of phi_model
|
161 |
+
# ReFL config
|
162 |
+
phi_ReFL: False # enable reward feedback learning (ReFL)
|
163 |
+
n_phi_sample: 1 # number of samples used in ReFL
|
164 |
+
phi_sample_step: 200 # the phi log step
|
165 |
+
phi_infer_step: 50 # the phi num_inference_steps
|
166 |
+
# phi model optim
|
167 |
+
phi_optim:
|
168 |
+
name: 'adamw'
|
169 |
+
betas: [ 0.9, 0.999 ]
|
170 |
+
eps: 1e-8
|
171 |
+
weight_decay: ~ # 1e-5
|
172 |
+
# phi model lr learning schedule
|
173 |
+
phi_schedule:
|
174 |
+
use: False
|
175 |
+
name: 'cosine'
|
176 |
+
warmup_steps: 50
|
177 |
+
warmup_start_lr: 0.00001
|
178 |
+
warmup_end_lr: 0.0001
|
179 |
+
total_step: 800
|
180 |
+
cosine_end_lr: 0.0001
|
181 |
+
|
182 |
+
# reward model
|
183 |
+
reward_path: './checkpoint/ImageReward'
|
184 |
+
|
185 |
+
# xing loss for closed-form paths
|
186 |
+
xing_loss:
|
187 |
+
use: False
|
188 |
+
weight: 0.01
|
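The `t_schedule` options are only described in comments; the helper below spells out one plausible reading of them ('randint' samples t uniformly in `t_range`, 'max_K_N' anneals the upper bound from t_range[1] down to K over the first N steps). The function name and parsing are illustrative only, not the project's implementation.

import random

def sample_t(step: int, t_schedule: str, t_range=(0.02, 0.98)) -> float:
    """Illustrative timestep sampler for the schedules named in the config comments."""
    t_min, t_max = t_range
    if t_schedule == 'randint':
        # 'randint': uniformly random time steps over the whole range
        return random.uniform(t_min, t_max)
    if t_schedule.startswith('max_'):
        # 'max_0.5_900': anneal the upper bound from t_max down to 0.5 over 900 steps
        _, end_max, n_steps = t_schedule.split('_')
        end_max, n_steps = float(end_max), int(n_steps)
        frac = min(step / n_steps, 1.0)
        cur_max = t_max + frac * (end_max - t_max)
        return random.uniform(t_min, cur_max)
    raise ValueError(f"unknown t_schedule: {t_schedule}")

# e.g. sample_t(0, 'max_0.5_900') may return ~0.98, while sample_t(900, 'max_0.5_900') <= 0.5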
conf/x/pixelart.yaml
ADDED
@@ -0,0 +1,188 @@
1 |
+
image_size: 600 # canvas size
|
2 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
3 |
+
color_init: 'rand' # if skip_live=True, then use color_init to init target_img
|
4 |
+
style: "pixelart" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
|
5 |
+
|
6 |
+
# stable diffusion in SIVE stage
|
7 |
+
sive_model_cfg:
|
8 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
9 |
+
ldm_speed_up: False
|
10 |
+
enable_xformers: True
|
11 |
+
gradient_checkpoint: False
|
12 |
+
cpu_offload: True
|
13 |
+
num_inference_steps: 100
|
14 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
15 |
+
lora_path: ~
|
16 |
+
|
17 |
+
# lr and optim
|
18 |
+
sive_stage_optim:
|
19 |
+
point: 1 # control points
|
20 |
+
width: 0.1 # stroke width
|
21 |
+
color: 0.01 # fill color and stroke color
|
22 |
+
bg: 0.01 # bg in render_warp
|
23 |
+
optim:
|
24 |
+
name: 'adam'
|
25 |
+
betas: [ 0.9, 0.9 ]
|
26 |
+
eps: 1e-6
|
27 |
+
schedule:
|
28 |
+
name: 'linear'
|
29 |
+
keep_ratio: 0.2
|
30 |
+
decay_ratio: 0.4
|
31 |
+
|
32 |
+
# SIVE rendering
|
33 |
+
sive:
|
34 |
+
attn_cfg: # init content via attn
|
35 |
+
cross_attn_res: 16
|
36 |
+
self_attn_res: 32
|
37 |
+
max_com: 20
|
38 |
+
mean_comp: False
|
39 |
+
comp_idx: 0
|
40 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
bg:
|
42 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
+
num_iter: 10
|
44 |
+
num_paths: 256
|
45 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
46 |
+
schedule_each: 128
|
47 |
+
width: 3 # sketch stroke width
|
48 |
+
num_segments: 4
|
49 |
+
segment_init: 'circle' # 'random'
|
50 |
+
radius: 20
|
51 |
+
coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
|
52 |
+
grid: 20
|
53 |
+
# optim
|
54 |
+
lr_schedule: True
|
55 |
+
optim_bg: False # train background
|
56 |
+
use_attn_init: True
|
57 |
+
softmax_tau: 0.3 # temperature of softmax
|
58 |
+
# loss
|
59 |
+
use_distance_weighted_loss: False
|
60 |
+
xing_loss_weight: 0.001
|
61 |
+
fg:
|
62 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
63 |
+
num_iter: 10
|
64 |
+
num_paths: 256 # number of strokes
|
65 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
66 |
+
schedule_each: 128
|
67 |
+
width: 3 # sketch stroke width
|
68 |
+
num_segments: 4
|
69 |
+
segment_init: 'circle' # 'random'
|
70 |
+
radius: 15
|
71 |
+
coord_init: 'random' # 'random', 'naive', place the first control point
|
72 |
+
grid: 20
|
73 |
+
# optim
|
74 |
+
lr_schedule: False
|
75 |
+
optim_bg: False # train background
|
76 |
+
use_attn_init: True
|
77 |
+
softmax_tau: 0.3 # temperature of softmax
|
78 |
+
# loss
|
79 |
+
use_distance_weighted_loss: False
|
80 |
+
xing_loss_weight: 0.01
|
81 |
+
tog: # for refinement
|
82 |
+
reinit: True # if False, use fg params to init content
|
83 |
+
num_iter: 1000
|
84 |
+
# optim
|
85 |
+
lr_schedule: False # enable lr_scheduler or not
|
86 |
+
# loss
|
87 |
+
bg_lam: 0
|
88 |
+
fg_lam: 1
|
89 |
+
xing_loss_weight: 0
|
90 |
+
|
91 |
+
# VPSD primitives
|
92 |
+
num_paths: 512 # number of strokes
|
93 |
+
trainable_bg: False # set the background to be trainable
|
94 |
+
width: 3 # stroke width
|
95 |
+
num_segments: 4
|
96 |
+
segment_init: 'circle' # 'random'
|
97 |
+
radius: 20
|
98 |
+
coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
|
99 |
+
grid: 50 # divide the canvas into n grids
|
100 |
+
path_reinit: # reinitializing paths
|
101 |
+
use: True
|
102 |
+
freq: 100 # every 100 iterations
|
103 |
+
stop_step: 1000 # for VPSD fine-tuning
|
104 |
+
opacity_threshold: 0.05
|
105 |
+
area_threshold: 64
|
106 |
+
|
107 |
+
# lr and optim
|
108 |
+
vpsd_stage_optim:
|
109 |
+
point: 1
|
110 |
+
width: 0.1
|
111 |
+
color: 0.01
|
112 |
+
bg: 0.01
|
113 |
+
lr_schedule: True # use lr_scheduler
|
114 |
+
optim:
|
115 |
+
name: 'adam'
|
116 |
+
betas: [ 0.9, 0.9 ]
|
117 |
+
eps: 1e-6
|
118 |
+
schedule:
|
119 |
+
name: 'cosine'
|
120 |
+
warmup_steps: 10
|
121 |
+
warmup_start_lr: 0.02
|
122 |
+
warmup_end_lr: 0.9
|
123 |
+
cosine_end_lr: 0.4
|
124 |
+
|
125 |
+
# stable diffusion in VPSD stage
|
126 |
+
vpsd_model_cfg:
|
127 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
128 |
+
ldm_speed_up: False
|
129 |
+
enable_xformers: True
|
130 |
+
gradient_checkpoint: False
|
131 |
+
cpu_offload: True
|
132 |
+
num_inference_steps: 100
|
133 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
134 |
+
lora_path: ~
|
135 |
+
|
136 |
+
# VPSD setting
|
137 |
+
vpsd:
|
138 |
+
type: 'vpsd'
|
139 |
+
n_particle: 6 # 4, 8, 16
|
140 |
+
vsd_n_particle: 4 # the batch size of particles
|
141 |
+
particle_aug: False # apply data augmentation to the input particles
|
142 |
+
num_iter: 1000 # total iterations
|
143 |
+
guidance_scale: 7.5 # CFG value
|
144 |
+
grad_scale: 1.0 # increase or decrease the gradient
|
145 |
+
grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
|
146 |
+
t_range: [ 0.02, 0.98 ]
|
147 |
+
# 'randint': random time steps; this may give a more authentic style.
# 'max_0.5_900': anneal from 0.98 to 0.5 after 900 steps; this may give more colorful results.
t_schedule: 'max_0.5_1500' # or 'randint'
|
150 |
+
# phi model config
|
151 |
+
phi_single: False # if False, instantiate a new UNet model to estimate noise
|
152 |
+
phi_model: 'lora' # 'lora', 'unet_simple'
|
153 |
+
use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
|
154 |
+
lora_attn_scale: 1.0 # the scale of the attn based lora layer
|
155 |
+
phi_guidance_scale: 1.0
|
156 |
+
phi_t: False # different t for phi fine-tuning
|
157 |
+
phi_update_step: 1 # enable multi-update phi model or not
|
158 |
+
phi_lr: 0.0001 # learning rate of phi model
|
159 |
+
phi_scheduler: 'ddim'
|
160 |
+
phi_n_particle: 2 # the batch size of phi_model
|
161 |
+
# ReFL config
|
162 |
+
phi_ReFL: False # enable reward feedback learning (ReFL)
|
163 |
+
n_phi_sample: 1 # number of samples used in ReFL
|
164 |
+
phi_sample_step: 200 # the phi log step
|
165 |
+
phi_infer_step: 50 # the phi num_inference_steps
|
166 |
+
# phi model optim
|
167 |
+
phi_optim:
|
168 |
+
name: 'adamw'
|
169 |
+
betas: [ 0.9, 0.999 ]
|
170 |
+
eps: 1e-8
|
171 |
+
weight_decay: ~ # 1e-5
|
172 |
+
# phi model lr learning schedule
|
173 |
+
phi_schedule:
|
174 |
+
use: False
|
175 |
+
name: 'cosine'
|
176 |
+
warmup_steps: 50
|
177 |
+
warmup_start_lr: 0.00001
|
178 |
+
warmup_end_lr: 0.0001
|
179 |
+
total_step: 800
|
180 |
+
cosine_end_lr: 0.0001
|
181 |
+
|
182 |
+
# reward model
|
183 |
+
reward_path: './checkpoint/ImageReward'
|
184 |
+
|
185 |
+
# xing loss for closed-form paths
|
186 |
+
xing_loss:
|
187 |
+
use: False
|
188 |
+
weight: 0.01
|
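The `path_reinit` block drives periodic re-initialization of paths that contribute little to the rendering. The sketch below shows the selection rule those fields suggest (opacity below `opacity_threshold` or area below `area_threshold`, checked every `freq` steps until `stop_step`); it illustrates the config semantics under those assumptions and is not the project's implementation.

def paths_to_reinit(step, opacities, areas, cfg):
    """Return indices of paths that should be re-initialized at this step (illustrative).

    cfg mirrors the path_reinit block: use, freq, stop_step, opacity_threshold, area_threshold.
    """
    if not cfg['use'] or step > cfg['stop_step'] or step % cfg['freq'] != 0:
        return []
    return [
        i for i, (op, area) in enumerate(zip(opacities, areas))
        if op < cfg['opacity_threshold'] or area < cfg['area_threshold']
    ]

cfg = dict(use=True, freq=100, stop_step=1000, opacity_threshold=0.05, area_threshold=64)
print(paths_to_reinit(200, [0.9, 0.01, 0.7], [500.0, 300.0, 10.0], cfg))  # -> [1, 2]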
conf/x/sketch.yaml
ADDED
@@ -0,0 +1,188 @@
1 |
+
image_size: 600 # canvas size
|
2 |
+
path_svg: ~ # if you want to load a svg file and train from it
|
3 |
+
color_init: 'rand' # if skip_live=True, then use color_init to init target_img
|
4 |
+
style: "sketch" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
|
5 |
+
|
6 |
+
# stable diffusion in SIVE stage
|
7 |
+
sive_model_cfg:
|
8 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
9 |
+
ldm_speed_up: False
|
10 |
+
enable_xformers: True
|
11 |
+
gradient_checkpoint: False
|
12 |
+
cpu_offload: True
|
13 |
+
num_inference_steps: 100
|
14 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
15 |
+
lora_path: ~
|
16 |
+
|
17 |
+
# lr and optim
|
18 |
+
sive_stage_optim:
|
19 |
+
point: 1 # control points
|
20 |
+
width: 0.1 # stroke width
|
21 |
+
color: 0.01 # fill color and stroke color
|
22 |
+
bg: 0.01 # bg in render_warp
|
23 |
+
optim:
|
24 |
+
name: 'adam'
|
25 |
+
betas: [ 0.9, 0.9 ]
|
26 |
+
eps: 1e-6
|
27 |
+
schedule:
|
28 |
+
name: 'linear'
|
29 |
+
keep_ratio: 0.2
|
30 |
+
decay_ratio: 0.4
|
31 |
+
|
32 |
+
# SIVE rendering
|
33 |
+
sive:
|
34 |
+
attn_cfg: # init content via attn
|
35 |
+
cross_attn_res: 16
|
36 |
+
self_attn_res: 32
|
37 |
+
max_com: 20
|
38 |
+
mean_comp: False
|
39 |
+
comp_idx: 0
|
40 |
+
attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
|
41 |
+
bg:
|
42 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
43 |
+
num_iter: 10
|
44 |
+
num_paths: 256
|
45 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
46 |
+
schedule_each: 128
|
47 |
+
width: 3 # sketch stroke width
|
48 |
+
num_segments: 4
|
49 |
+
segment_init: 'circle' # 'random'
|
50 |
+
radius: 20
|
51 |
+
coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
|
52 |
+
grid: 20
|
53 |
+
# optim
|
54 |
+
lr_schedule: True
|
55 |
+
optim_bg: False # train background
|
56 |
+
use_attn_init: True
|
57 |
+
softmax_tau: 0.3 # temperature of softmax
|
58 |
+
# loss
|
59 |
+
use_distance_weighted_loss: False
|
60 |
+
xing_loss_weight: 0.001
|
61 |
+
fg:
|
62 |
+
style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
|
63 |
+
num_iter: 10
|
64 |
+
num_paths: 256 # number of strokes
|
65 |
+
path_schedule: 'repeat' # 'repeat', 'list'
|
66 |
+
schedule_each: 128
|
67 |
+
width: 3 # sketch stroke width
|
68 |
+
num_segments: 4
|
69 |
+
segment_init: 'circle' # 'random'
|
70 |
+
radius: 15
|
71 |
+
coord_init: 'random' # 'random', 'naive', place the first control point
|
72 |
+
grid: 20
|
73 |
+
# optim
|
74 |
+
lr_schedule: False
|
75 |
+
optim_bg: False # train background
|
76 |
+
use_attn_init: True
|
77 |
+
softmax_tau: 0.3 # temperature of softmax
|
78 |
+
# loss
|
79 |
+
use_distance_weighted_loss: False
|
80 |
+
xing_loss_weight: 0.01
|
81 |
+
tog: # for refinement
|
82 |
+
reinit: True # if False, use fg params to init content
|
83 |
+
num_iter: 1000
|
84 |
+
# optim
|
85 |
+
lr_schedule: False # enable lr_scheduler or not
|
86 |
+
# loss
|
87 |
+
bg_lam: 0
|
88 |
+
fg_lam: 1
|
89 |
+
xing_loss_weight: 0
|
90 |
+
|
91 |
+
# VPSD primitives
|
92 |
+
num_paths: 128 # number of strokes
|
93 |
+
trainable_bg: False # set the background to be trainable
|
94 |
+
width: 3 # stroke width
|
95 |
+
num_segments: 4
|
96 |
+
segment_init: 'circle' # 'random'
|
97 |
+
radius: 20
|
98 |
+
coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
|
99 |
+
grid: 50 # divide the canvas into n grids
|
100 |
+
path_reinit: # reinitializing paths
|
101 |
+
use: True
|
102 |
+
freq: 100 # every 100 iterations
|
103 |
+
stop_step: 1000 # for VPSD fine-tuning
|
104 |
+
opacity_threshold: 0.05
|
105 |
+
area_threshold: 64
|
106 |
+
|
107 |
+
# lr and optim
|
108 |
+
vpsd_stage_optim:
|
109 |
+
point: 1
|
110 |
+
width: 0.1
|
111 |
+
color: 0.01
|
112 |
+
bg: 0.01
|
113 |
+
lr_schedule: True # use lr_scheduler
|
114 |
+
optim:
|
115 |
+
name: 'adam'
|
116 |
+
betas: [ 0.9, 0.9 ]
|
117 |
+
eps: 1e-6
|
118 |
+
schedule:
|
119 |
+
name: 'cosine'
|
120 |
+
warmup_steps: 10
|
121 |
+
warmup_start_lr: 0.02
|
122 |
+
warmup_end_lr: 0.9
|
123 |
+
cosine_end_lr: 0.4
|
124 |
+
|
125 |
+
# stable diffusion in VPSD stage
|
126 |
+
vpsd_model_cfg:
|
127 |
+
model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
|
128 |
+
ldm_speed_up: False
|
129 |
+
enable_xformers: True
|
130 |
+
gradient_checkpoint: False
|
131 |
+
cpu_offload: True
|
132 |
+
num_inference_steps: 100
|
133 |
+
guidance_scale: 7.5 # sdxl default 5.0
|
134 |
+
lora_path: ~
|
135 |
+
|
136 |
+
# VPSD setting
|
137 |
+
vpsd:
|
138 |
+
type: 'vpsd'
|
139 |
+
n_particle: 6 # 4, 8, 16
|
140 |
+
vsd_n_particle: 4 # the batch size of particles
|
141 |
+
particle_aug: False # apply data augmentation to the input particles
|
142 |
+
num_iter: 2000 # total iterations
|
143 |
+
guidance_scale: 7.5 # CFG value
|
144 |
+
grad_scale: 1.0 # increase or decrease the gradient
|
145 |
+
grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
|
146 |
+
t_range: [ 0.02, 0.98 ]
|
147 |
+
# 'randint': random time steps; this may give a more authentic style.
# 'max_0.5_900': anneal from 0.98 to 0.5 after 900 steps; this may give more colorful results.
t_schedule: 'randint' # or e.g. 'max_0.5_900'
|
150 |
+
# phi model config
|
151 |
+
phi_single: False # if False, instantiate a new UNet model to estimate noise
|
152 |
+
phi_model: 'lora' # 'lora', 'unet_simple'
|
153 |
+
use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
|
154 |
+
lora_attn_scale: 1.0 # the scale of the attn based lora layer
|
155 |
+
phi_guidance_scale: 1.0
|
156 |
+
phi_t: False # different t for phi fine-tuning
|
157 |
+
phi_update_step: 1 # enable multi-update phi model or not
|
158 |
+
phi_lr: 0.0001 # learning rate of phi model
|
159 |
+
phi_scheduler: 'ddim'
|
160 |
+
phi_n_particle: 2 # the batch size of phi_model
|
161 |
+
# ReFL config
|
162 |
+
phi_ReFL: False # enable reward feedback learning (ReFL)
|
163 |
+
n_phi_sample: 1 # number of samples used in ReFL
|
164 |
+
phi_sample_step: 200 # the phi log step
|
165 |
+
phi_infer_step: 50 # the phi num_inference_steps
|
166 |
+
# phi model optim
|
167 |
+
phi_optim:
|
168 |
+
name: 'adamw'
|
169 |
+
betas: [ 0.9, 0.999 ]
|
170 |
+
eps: 1e-8
|
171 |
+
weight_decay: ~ # 1e-5
|
172 |
+
# phi model lr learning schedule
|
173 |
+
phi_schedule:
|
174 |
+
use: False
|
175 |
+
name: 'cosine'
|
176 |
+
warmup_steps: 50
|
177 |
+
warmup_start_lr: 0.00001
|
178 |
+
warmup_end_lr: 0.0001
|
179 |
+
total_step: 800
|
180 |
+
cosine_end_lr: 0.0001
|
181 |
+
|
182 |
+
# reward model
|
183 |
+
reward_path: './checkpoint/ImageReward'
|
184 |
+
|
185 |
+
# xing loss for closed-form paths
|
186 |
+
xing_loss:
|
187 |
+
use: False
|
188 |
+
weight: 0.01
|
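`vpsd_stage_optim` assigns a separate learning rate to control points, stroke widths, colors, and the background on top of a shared Adam configuration. A minimal sketch of how such a config maps onto torch.optim parameter groups follows; the tensors are placeholders standing in for the painter's actual variables.

import torch

# Placeholder parameters standing in for the painter's point/width/color/bg variables.
points = torch.zeros(10, 2, requires_grad=True)
widths = torch.zeros(10, requires_grad=True)
colors = torch.zeros(10, 4, requires_grad=True)
bg = torch.zeros(3, requires_grad=True)

optim_cfg = {'point': 1, 'width': 0.1, 'color': 0.01, 'bg': 0.01,
             'betas': (0.9, 0.9), 'eps': 1e-6}

optimizer = torch.optim.Adam(
    [
        {'params': [points], 'lr': optim_cfg['point']},
        {'params': [widths], 'lr': optim_cfg['width']},
        {'params': [colors], 'lr': optim_cfg['color']},
        {'params': [bg], 'lr': optim_cfg['bg']},
    ],
    betas=optim_cfg['betas'], eps=optim_cfg['eps'],
)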
svgdreamer.py
ADDED
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Author: ximing xing
# Description: the main entry point of this project.
# Copyright (c) 2023, XiMing Xing.

import os
import sys
from functools import partial

from accelerate.utils import set_seed
import hydra
import omegaconf

sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])

from svgdreamer.utils import render_batch_wrap, get_seed_range
from svgdreamer.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline


@hydra.main(version_base=None, config_path="conf", config_name='config')
def main(cfg: omegaconf.DictConfig):
    """
    The project configuration is stored in './conf/config.yaml'.
    Style configurations are stored in './conf/x/', e.g. './conf/x/iconography.yaml'.
    """

    # set seed
    set_seed(cfg.seed)
    seed_range = get_seed_range(cfg.srange) if cfg.multirun else None

    # render function
    render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range)

    if not cfg.multirun:  # single run: generate one SVG result
        pipe = SVGDreamerPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt)
    else:  # multirun: generate many SVGs at once over a range of seeds
        render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None)


if __name__ == '__main__':
    main()
svgdreamer/__init__.py
ADDED
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
# Author: ximing
# Copyright (c) 2023, XiMing Xing.
# License: MIT

__version__ = "1.0"
svgdreamer/diffusers_warp/__init__.py
ADDED
@@ -0,0 +1,248 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) XiMing Xing. All rights reserved.
|
3 |
+
# Author: XiMing Xing
|
4 |
+
# Description:
|
5 |
+
from typing import AnyStr
|
6 |
+
import pathlib
|
7 |
+
from collections import OrderedDict
|
8 |
+
from packaging import version
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from diffusers import StableDiffusionPipeline, SchedulerMixin
|
12 |
+
from diffusers import UNet2DConditionModel
|
13 |
+
from diffusers.utils import is_torch_version, is_xformers_available
|
14 |
+
|
15 |
+
DiffusersModels = OrderedDict({
|
16 |
+
"sd14": "CompVis/stable-diffusion-v1-4", # resolution: 512
|
17 |
+
"sd15": "runwayml/stable-diffusion-v1-5", # resolution: 512
|
18 |
+
"sd21b": "stabilityai/stable-diffusion-2-1-base", # resolution: 512
|
19 |
+
"sd21": "stabilityai/stable-diffusion-2-1", # resolution: 768
|
20 |
+
"sdxl": "stabilityai/stable-diffusion-xl-base-1.0", # resolution: 1024
|
21 |
+
})
|
22 |
+
|
23 |
+
# default resolution
|
24 |
+
_model2resolution = {
|
25 |
+
"sd14": 512,
|
26 |
+
"sd15": 512,
|
27 |
+
"sd21b": 512,
|
28 |
+
"sd21": 768,
|
29 |
+
"sdxl": 1024,
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
def model2res(model_id: str):
|
34 |
+
return _model2resolution.get(model_id, 512)
|
35 |
+
|
36 |
+
|
37 |
+
def init_StableDiffusion_pipeline(model_id: AnyStr,
|
38 |
+
custom_pipeline: StableDiffusionPipeline,
|
39 |
+
custom_scheduler: SchedulerMixin = None,
|
40 |
+
device: torch.device = "cuda",
|
41 |
+
torch_dtype: torch.dtype = torch.float32,
|
42 |
+
local_files_only: bool = True,
|
43 |
+
force_download: bool = False,
|
44 |
+
resume_download: bool = False,
|
45 |
+
ldm_speed_up: bool = False,
|
46 |
+
enable_xformers: bool = True,
|
47 |
+
gradient_checkpoint: bool = False,
|
48 |
+
cpu_offload: bool = False,
|
49 |
+
vae_slicing: bool = False,
|
50 |
+
lora_path: AnyStr = None,
|
51 |
+
unet_path: AnyStr = None) -> StableDiffusionPipeline:
|
52 |
+
"""
|
53 |
+
A helper for initializing a diffusers pipeline.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path
|
57 |
+
custom_pipeline: any StableDiffusionPipeline pipeline
|
58 |
+
custom_scheduler: any scheduler
|
59 |
+
device: set device
|
60 |
+
torch_dtype: data type
|
61 |
+
local_files_only: use local files only, do not download the model
force_download: force re-downloading the model
resume_download: resume a previously interrupted download
|
64 |
+
ldm_speed_up: use the `torch.compile` api to speed up unet
|
65 |
+
enable_xformers: enable memory efficient attention from [xFormers]
|
66 |
+
gradient_checkpoint: activates gradient checkpointing for the current model
|
67 |
+
cpu_offload: enable sequential cpu offload
|
68 |
+
vae_slicing: enable sliced VAE decoding
|
69 |
+
lora_path: load LoRA checkpoint
|
70 |
+
unet_path: load unet checkpoint
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
diffusers.StableDiffusionPipeline
|
74 |
+
"""
|
75 |
+
|
76 |
+
# get model id
|
77 |
+
model_id = DiffusersModels.get(model_id, model_id)
|
78 |
+
|
79 |
+
# process diffusion model
|
80 |
+
if custom_scheduler is not None:
|
81 |
+
pipeline = custom_pipeline.from_pretrained(
|
82 |
+
model_id,
|
83 |
+
torch_dtype=torch_dtype,
|
84 |
+
local_files_only=local_files_only,
|
85 |
+
force_download=force_download,
|
86 |
+
resume_download=resume_download,
|
87 |
+
scheduler=custom_scheduler.from_pretrained(model_id,
|
88 |
+
subfolder="scheduler",
|
89 |
+
local_files_only=local_files_only,
|
90 |
+
force_download=force_download,
|
91 |
+
resume_download=resume_download)
|
92 |
+
).to(device)
|
93 |
+
else:
|
94 |
+
pipeline = custom_pipeline.from_pretrained(
|
95 |
+
model_id,
|
96 |
+
torch_dtype=torch_dtype,
|
97 |
+
local_files_only=local_files_only,
|
98 |
+
force_download=force_download,
|
99 |
+
resume_download=resume_download,
|
100 |
+
).to(device)
|
101 |
+
|
102 |
+
print(f"load diffusers pipeline: {model_id}")
|
103 |
+
|
104 |
+
# process unet model if exist
|
105 |
+
if unet_path is not None and pathlib.Path(unet_path).exists():
|
106 |
+
print(f"=> load u-net from {unet_path}")
|
107 |
+
pipeline.unet.from_pretrained(model_id, subfolder="unet")
|
108 |
+
|
109 |
+
# process lora layers if exist
|
110 |
+
if lora_path is not None and pathlib.Path(lora_path).exists():
|
111 |
+
pipeline.unet.load_attn_procs(lora_path)
|
112 |
+
print(f"=> load lora layers into U-Net from {lora_path} ...")
|
113 |
+
|
114 |
+
# torch.compile
|
115 |
+
if ldm_speed_up:
|
116 |
+
if is_torch_version(">=", "2.0.0"):
|
117 |
+
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
118 |
+
print(f"=> enable torch.compile on U-Net")
|
119 |
+
else:
|
120 |
+
print(f"=> warning: calling torch.compile speed-up failed, since torch version <= 2.0.0")
|
121 |
+
|
122 |
+
# Meta xformers
|
123 |
+
if enable_xformers:
|
124 |
+
if is_xformers_available():
|
125 |
+
import xformers
|
126 |
+
|
127 |
+
xformers_version = version.parse(xformers.__version__)
|
128 |
+
if xformers_version == version.parse("0.0.16"):
|
129 |
+
print(
|
130 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. "
|
131 |
+
"If you observe problems during training, please update xFormers to at least 0.0.17. "
|
132 |
+
"See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
133 |
+
)
|
134 |
+
print(f"=> enable xformers")
|
135 |
+
pipeline.unet.enable_xformers_memory_efficient_attention()
|
136 |
+
else:
|
137 |
+
print(f"=> warning: xformers is not available.")
|
138 |
+
|
139 |
+
# gradient checkpointing
|
140 |
+
if gradient_checkpoint:
|
141 |
+
# if pipeline.unet.is_gradient_checkpointing:
|
142 |
+
if True:
|
143 |
+
print(f"=> enable gradient checkpointing")
|
144 |
+
pipeline.unet.enable_gradient_checkpointing()
|
145 |
+
else:
|
146 |
+
print("=> waring: gradient checkpointing is not activated for this model.")
|
147 |
+
|
148 |
+
if cpu_offload:
|
149 |
+
pipeline.enable_sequential_cpu_offload()
|
150 |
+
|
151 |
+
if vae_slicing:
|
152 |
+
pipeline.enable_vae_slicing()
|
153 |
+
|
154 |
+
print(pipeline.scheduler)
|
155 |
+
return pipeline
|
156 |
+
|
157 |
+
|
158 |
+
def init_diffusers_unet(model_id: AnyStr,
|
159 |
+
device: torch.device = "cuda",
|
160 |
+
torch_dtype: torch.dtype = torch.float32,
|
161 |
+
local_files_only: bool = True,
|
162 |
+
force_download: bool = False,
|
163 |
+
resume_download: bool = False,
|
164 |
+
ldm_speed_up: bool = False,
|
165 |
+
enable_xformers: bool = True,
|
166 |
+
gradient_checkpoint: bool = False,
|
167 |
+
lora_path: AnyStr = None,
|
168 |
+
unet_path: AnyStr = None):
|
169 |
+
"""
|
170 |
+
A helper for initializing a diffusers UNet model.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path
|
174 |
+
device: set device
|
175 |
+
torch_dtype: data type
|
176 |
+
local_files_only: use local files only, do not download the model
force_download: force re-downloading the model
resume_download: resume a previously interrupted download
|
179 |
+
ldm_speed_up: use the `torch.compile` api to speed up unet
|
180 |
+
enable_xformers: enable memory efficient attention from [xFormers]
|
181 |
+
gradient_checkpoint: activates gradient checkpointing for the current model
|
182 |
+
lora_path: load LoRA checkpoint
|
183 |
+
unet_path: load unet checkpoint
|
184 |
+
|
185 |
+
Returns:
|
186 |
+
diffusers.UNet
|
187 |
+
"""
|
188 |
+
|
189 |
+
# get model id
|
190 |
+
model_id = DiffusersModels.get(model_id, model_id)
|
191 |
+
|
192 |
+
# process UNet model
|
193 |
+
unet = UNet2DConditionModel.from_pretrained(
|
194 |
+
model_id,
|
195 |
+
subfolder="unet",
|
196 |
+
torch_dtype=torch_dtype,
|
197 |
+
local_files_only=local_files_only,
|
198 |
+
force_download=force_download,
|
199 |
+
resume_download=resume_download,
|
200 |
+
).to(device)
|
201 |
+
|
202 |
+
print(f"load diffusers UNet: {model_id}")
|
203 |
+
|
204 |
+
# process unet model if exist
|
205 |
+
if unet_path is not None and pathlib.Path(unet_path).exists():
|
206 |
+
print(f"=> load u-net from {unet_path}")
|
207 |
+
unet.from_pretrained(model_id)
|
208 |
+
|
209 |
+
# process lora layers if exist
|
210 |
+
if lora_path is not None and pathlib.Path(lora_path).exists():
|
211 |
+
unet.load_attn_procs(lora_path)
|
212 |
+
print(f"=> load lora layers into U-Net from {lora_path} ...")
|
213 |
+
|
214 |
+
# torch.compile
|
215 |
+
if ldm_speed_up:
|
216 |
+
if is_torch_version(">=", "2.0.0"):
|
217 |
+
unet = torch.compile(unet, mode="reduce-overhead", fullgraph=True)
|
218 |
+
print(f"=> enable torch.compile on U-Net")
|
219 |
+
else:
|
220 |
+
print(f"=> warning: calling torch.compile speed-up failed, since torch version <= 2.0.0")
|
221 |
+
|
222 |
+
# Meta xformers
|
223 |
+
if enable_xformers:
|
224 |
+
if is_xformers_available():
|
225 |
+
import xformers
|
226 |
+
|
227 |
+
xformers_version = version.parse(xformers.__version__)
|
228 |
+
if xformers_version == version.parse("0.0.16"):
|
229 |
+
print(
|
230 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. "
|
231 |
+
"If you observe problems during training, please update xFormers to at least 0.0.17. "
|
232 |
+
"See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
233 |
+
)
|
234 |
+
print(f"=> enable xformers")
|
235 |
+
unet.enable_xformers_memory_efficient_attention()
|
236 |
+
else:
|
237 |
+
print(f"=> warning: xformers is not available.")
|
238 |
+
|
239 |
+
# gradient checkpointing
|
240 |
+
if gradient_checkpoint:
|
241 |
+
# if unet.is_gradient_checkpointing:
|
242 |
+
if True:
|
243 |
+
print(f"=> enable gradient checkpointing")
|
244 |
+
unet.enable_gradient_checkpointing()
|
245 |
+
else:
|
246 |
+
print("=> waring: gradient checkpointing is not activated for this model.")
|
247 |
+
|
248 |
+
return unet
|
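A short usage sketch for the helper above, passing the stock diffusers classes through the `custom_pipeline` / `custom_scheduler` hooks. The argument values (local_files_only=False, xformers and offload disabled) are illustrative choices, not how the rest of this repo necessarily calls it.

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

from svgdreamer.diffusers_warp import init_StableDiffusion_pipeline

pipe = init_StableDiffusion_pipeline(
    "sd21b",                               # resolved via DiffusersModels to stable-diffusion-2-1-base
    custom_pipeline=StableDiffusionPipeline,
    custom_scheduler=DDIMScheduler,
    device=torch.device("cuda"),
    local_files_only=False,                # allow downloading the weights on first use
    enable_xformers=False,
    cpu_offload=False,
)
images = pipe("a minimal flat icon of the Sydney Opera House").images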
svgdreamer/diffvg_warp/__init__.py
ADDED
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

from .diffvg_state import DiffVGState, init_pydiffvg

__all__ = [
    'DiffVGState',
    'init_pydiffvg'
]
svgdreamer/diffvg_warp/diffvg_state.py
ADDED
@@ -0,0 +1,299 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Author: ximing
|
3 |
+
# Description: parent class
|
4 |
+
# Copyright (c) 2023, XiMing Xing.
|
5 |
+
# License: MIT License
|
6 |
+
import pathlib
|
7 |
+
from typing import AnyStr, List, Union
|
8 |
+
import xml.etree.ElementTree as etree
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import pydiffvg
|
12 |
+
|
13 |
+
|
14 |
+
def init_pydiffvg(device: torch.device,
|
15 |
+
use_gpu: bool = torch.cuda.is_available(),
|
16 |
+
print_timing: bool = False):
|
17 |
+
pydiffvg.set_use_gpu(use_gpu)
|
18 |
+
pydiffvg.set_device(device)
|
19 |
+
pydiffvg.set_print_timing(print_timing)
|
20 |
+
|
21 |
+
|
22 |
+
class DiffVGState(torch.nn.Module):
|
23 |
+
|
24 |
+
def __init__(self,
|
25 |
+
device: torch.device,
|
26 |
+
use_gpu: bool = torch.cuda.is_available(),
|
27 |
+
print_timing: bool = False,
|
28 |
+
canvas_width: int = None,
|
29 |
+
canvas_height: int = None):
|
30 |
+
super(DiffVGState, self).__init__()
|
31 |
+
# pydiffvg device setting
|
32 |
+
self.device = device
|
33 |
+
init_pydiffvg(device, use_gpu, print_timing)
|
34 |
+
|
35 |
+
# canvas size
|
36 |
+
self.canvas_width = canvas_width
|
37 |
+
self.canvas_height = canvas_height
|
38 |
+
|
39 |
+
# record all paths
|
40 |
+
self.shapes = []
|
41 |
+
self.shape_groups = []
|
42 |
+
# record the current optimized path
|
43 |
+
self.cur_shapes = []
|
44 |
+
self.cur_shape_groups = []
|
45 |
+
|
46 |
+
# learnable SVG params
|
47 |
+
self.point_vars = []
|
48 |
+
self.color_vars = []
|
49 |
+
self.width_vars = []
|
50 |
+
|
51 |
+
def clip_curve_shape(self, *args, **kwargs):
|
52 |
+
raise NotImplementedError
|
53 |
+
|
54 |
+
def render_warp(self, seed=0):
|
55 |
+
self.clip_curve_shape()
|
56 |
+
|
57 |
+
scene_args = pydiffvg.RenderFunction.serialize_scene(
|
58 |
+
self.canvas_width, self.canvas_height, self.shapes, self.shape_groups
|
59 |
+
)
|
60 |
+
_render = pydiffvg.RenderFunction.apply
|
61 |
+
img = _render(self.canvas_width, # width
|
62 |
+
self.canvas_height, # height
|
63 |
+
2, # num_samples_x
|
64 |
+
2, # num_samples_y
|
65 |
+
seed, # seed
|
66 |
+
None,
|
67 |
+
*scene_args)
|
68 |
+
return img
|
69 |
+
|
70 |
+
def render_image(self, canvas_width, canvas_height, shapes, shape_groups, seed: int = 0):
|
71 |
+
scene_args = pydiffvg.RenderFunction.serialize_scene(
|
72 |
+
canvas_width, canvas_height, shapes, shape_groups
|
73 |
+
)
|
74 |
+
_render = pydiffvg.RenderFunction.apply
|
75 |
+
img = _render(canvas_width, # width
|
76 |
+
canvas_height, # height
|
77 |
+
2, # num_samples_x
|
78 |
+
2, # num_samples_y
|
79 |
+
seed, # seed
|
80 |
+
None,
|
81 |
+
*scene_args)
|
82 |
+
img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4])
|
83 |
+
img = img.unsqueeze(0)  # convert img from HWC to NHWC (add batch dim)
|
84 |
+
img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW
|
85 |
+
return img
|
86 |
+
|
87 |
+
@staticmethod
|
88 |
+
def load_svg(path_svg):
|
89 |
+
canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
|
90 |
+
return canvas_width, canvas_height, shapes, shape_groups
|
91 |
+
|
92 |
+
def save_svg(self,
|
93 |
+
filename: Union[AnyStr, pathlib.Path],
|
94 |
+
width: int = None,
|
95 |
+
height: int = None,
|
96 |
+
shapes: List = None,
|
97 |
+
shape_groups: List = None,
|
98 |
+
use_gamma: bool = False,
|
99 |
+
background: str = None):
|
100 |
+
"""
|
101 |
+
Save an SVG file with specified parameters and shapes.
|
102 |
+
Note: a new version of the SVG saving function, adapted from pydiffvg.save_svg.
The original version saved text in a way that produced incomplete glyphs.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
filename (str): The path to save the SVG file.
|
107 |
+
width (int): The width of the SVG canvas.
|
108 |
+
height (int): The height of the SVG canvas.
|
109 |
+
shapes (list): A list of shapes to be included in the SVG.
|
110 |
+
shape_groups (list): A list of shape groups.
|
111 |
+
use_gamma (bool): Flag indicating whether to apply gamma correction.
|
112 |
+
background (str, optional): The background color of the SVG.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
None
|
116 |
+
"""
|
117 |
+
root = etree.Element('svg')
|
118 |
+
root.set('version', '1.1')
|
119 |
+
root.set('xmlns', 'http://www.w3.org/2000/svg')
|
120 |
+
root.set('width', str(width))
|
121 |
+
root.set('height', str(height))
|
122 |
+
|
123 |
+
if background is not None:
|
124 |
+
print(f"setting background to {background}")
|
125 |
+
root.set('style', str(background))
|
126 |
+
|
127 |
+
defs = etree.SubElement(root, 'defs')
|
128 |
+
g = etree.SubElement(root, 'g')
|
129 |
+
|
130 |
+
if use_gamma:
|
131 |
+
f = etree.SubElement(defs, 'filter')
|
132 |
+
f.set('id', 'gamma')
|
133 |
+
f.set('x', '0')
|
134 |
+
f.set('y', '0')
|
135 |
+
f.set('width', '100%')
|
136 |
+
f.set('height', '100%')
|
137 |
+
gamma = etree.SubElement(f, 'feComponentTransfer')
|
138 |
+
gamma.set('color-interpolation-filters', 'sRGB')
|
139 |
+
feFuncR = etree.SubElement(gamma, 'feFuncR')
|
140 |
+
feFuncR.set('type', 'gamma')
|
141 |
+
feFuncR.set('amplitude', str(1))
|
142 |
+
feFuncR.set('exponent', str(1 / 2.2))
|
143 |
+
feFuncG = etree.SubElement(gamma, 'feFuncG')
|
144 |
+
feFuncG.set('type', 'gamma')
|
145 |
+
feFuncG.set('amplitude', str(1))
|
146 |
+
feFuncG.set('exponent', str(1 / 2.2))
|
147 |
+
feFuncB = etree.SubElement(gamma, 'feFuncB')
|
148 |
+
feFuncB.set('type', 'gamma')
|
149 |
+
feFuncB.set('amplitude', str(1))
|
150 |
+
feFuncB.set('exponent', str(1 / 2.2))
|
151 |
+
feFuncA = etree.SubElement(gamma, 'feFuncA')
|
152 |
+
feFuncA.set('type', 'gamma')
|
153 |
+
feFuncA.set('amplitude', str(1))
|
154 |
+
feFuncA.set('exponent', str(1 / 2.2))
|
155 |
+
g.set('style', 'filter:url(#gamma)')
|
156 |
+
|
157 |
+
# Store color
|
158 |
+
for i, shape_group in enumerate(shape_groups):
|
159 |
+
def add_color(shape_color, name):
|
160 |
+
if isinstance(shape_color, pydiffvg.LinearGradient):
|
161 |
+
lg = shape_color
|
162 |
+
color = etree.SubElement(defs, 'linearGradient')
|
163 |
+
color.set('id', name)
|
164 |
+
color.set('x1', str(lg.begin[0].item()))
|
165 |
+
color.set('y1', str(lg.begin[1].item()))
|
166 |
+
color.set('x2', str(lg.end[0].item()))
|
167 |
+
color.set('y2', str(lg.end[1].item()))
|
168 |
+
offsets = lg.offsets.data.cpu().numpy()
|
169 |
+
stop_colors = lg.stop_colors.data.cpu().numpy()
|
170 |
+
for j in range(offsets.shape[0]):
|
171 |
+
stop = etree.SubElement(color, 'stop')
|
172 |
+
stop.set('offset', str(offsets[j]))
|
173 |
+
c = lg.stop_colors[j, :]
|
174 |
+
stop.set('stop-color', 'rgb({}, {}, {})'.format(
|
175 |
+
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
|
176 |
+
))
|
177 |
+
stop.set('stop-opacity', '{}'.format(c[3]))
|
178 |
+
if isinstance(shape_color, pydiffvg.RadialGradient):
|
179 |
+
lg = shape_color
|
180 |
+
color = etree.SubElement(defs, 'radialGradient')
|
181 |
+
color.set('id', name)
|
182 |
+
color.set('cx', str(lg.center[0].item() / width))
|
183 |
+
color.set('cy', str(lg.center[1].item() / height))
|
184 |
+
# this only supports width == height
|
185 |
+
color.set('r', str(lg.radius[0].item() / width))
|
186 |
+
offsets = lg.offsets.data.cpu().numpy()
|
187 |
+
stop_colors = lg.stop_colors.data.cpu().numpy()
|
188 |
+
for j in range(offsets.shape[0]):
|
189 |
+
stop = etree.SubElement(color, 'stop')
|
190 |
+
stop.set('offset', str(offsets[j]))
|
191 |
+
c = lg.stop_colors[j, :]
|
192 |
+
stop.set('stop-color', 'rgb({}, {}, {})'.format(
|
193 |
+
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
|
194 |
+
))
|
195 |
+
stop.set('stop-opacity', '{}'.format(c[3]))
|
196 |
+
|
197 |
+
if shape_group.fill_color is not None:
|
198 |
+
add_color(shape_group.fill_color, 'shape_{}_fill'.format(i))
|
199 |
+
if shape_group.stroke_color is not None:
|
200 |
+
add_color(shape_group.stroke_color, 'shape_{}_stroke'.format(i))
|
201 |
+
|
202 |
+
for i, shape_group in enumerate(shape_groups):
|
203 |
+
shape = shapes[shape_group.shape_ids[0]]
|
204 |
+
if isinstance(shape, pydiffvg.Circle):
|
205 |
+
shape_node = etree.SubElement(g, 'circle')
|
206 |
+
shape_node.set('r', str(shape.radius.item()))
|
207 |
+
shape_node.set('cx', str(shape.center[0].item()))
|
208 |
+
shape_node.set('cy', str(shape.center[1].item()))
|
209 |
+
elif isinstance(shape, pydiffvg.Polygon):
|
210 |
+
shape_node = etree.SubElement(g, 'polygon')
|
211 |
+
points = shape.points.data.cpu().numpy()
|
212 |
+
path_str = ''
|
213 |
+
for j in range(0, shape.points.shape[0]):
|
214 |
+
path_str += '{} {}'.format(points[j, 0], points[j, 1])
|
215 |
+
if j != shape.points.shape[0] - 1:
|
216 |
+
path_str += ' '
|
217 |
+
shape_node.set('points', path_str)
|
218 |
+
elif isinstance(shape, pydiffvg.Path):
|
219 |
+
for j, id in enumerate(shape_group.shape_ids):
|
220 |
+
shape = shapes[id]
|
221 |
+
if isinstance(shape, pydiffvg.Path):
|
222 |
+
if j == 0:
|
223 |
+
shape_node = etree.SubElement(g, 'path')
|
224 |
+
node_id = shape_node.get('id')
|
225 |
+
path_str = ''
|
226 |
+
|
227 |
+
num_segments = shape.num_control_points.shape[0]
|
228 |
+
num_control_points = shape.num_control_points.data.cpu().numpy()
|
229 |
+
points = shape.points.data.cpu().numpy()
|
230 |
+
num_points = shape.points.shape[0]
|
231 |
+
path_str += 'M {} {}'.format(points[0, 0], points[0, 1])
|
232 |
+
point_id = 1
|
233 |
+
for j in range(0, num_segments):
|
234 |
+
if num_control_points[j] == 0:
|
235 |
+
p = point_id % num_points
|
236 |
+
path_str += ' L {} {}'.format(
|
237 |
+
points[p, 0], points[p, 1])
|
238 |
+
point_id += 1
|
239 |
+
elif num_control_points[j] == 1:
|
240 |
+
p1 = (point_id + 1) % num_points
|
241 |
+
path_str += ' Q {} {} {} {}'.format(
|
242 |
+
points[point_id, 0], points[point_id, 1],
|
243 |
+
points[p1, 0], points[p1, 1])
|
244 |
+
point_id += 2
|
245 |
+
elif num_control_points[j] == 2:
|
246 |
+
p2 = (point_id + 2) % num_points
|
247 |
+
path_str += ' C {} {} {} {} {} {}'.format(
|
248 |
+
points[point_id, 0], points[point_id, 1],
|
249 |
+
points[point_id + 1, 0], points[point_id + 1, 1],
|
250 |
+
points[p2, 0], points[p2, 1])
|
251 |
+
point_id += 3
|
252 |
+
if node_id is not None:
|
253 |
+
shape_node.set('id', node_id) # add id to Path
|
254 |
+
shape_node.set('d', path_str)
|
255 |
+
elif isinstance(shape, pydiffvg.Rect):
|
256 |
+
shape_node = etree.SubElement(g, 'rect')
|
257 |
+
shape_node.set('x', str(shape.p_min[0].item()))
|
258 |
+
shape_node.set('y', str(shape.p_min[1].item()))
|
259 |
+
shape_node.set('width', str(shape.p_max[0].item() - shape.p_min[0].item()))
|
260 |
+
shape_node.set('height', str(shape.p_max[1].item() - shape.p_min[1].item()))
|
261 |
+
elif isinstance(shape, pydiffvg.Ellipse):
|
262 |
+
shape_node = etree.SubElement(g, 'ellipse')
|
263 |
+
shape_node.set('cx', str(shape.center[0].item()))
|
264 |
+
shape_node.set('cy', str(shape.center[1].item()))
|
265 |
+
shape_node.set('rx', str(shape.radius[0].item()))
|
266 |
+
shape_node.set('ry', str(shape.radius[1].item()))
|
267 |
+
else:
|
268 |
+
raise NotImplementedError(f'shape type: {type(shape)} is not involved in pydiffvg.')
|
269 |
+
|
270 |
+
shape_node.set('stroke-width', str(2 * shape.stroke_width.data.cpu().item()))
|
271 |
+
if shape_group.fill_color is not None:
|
272 |
+
if isinstance(shape_group.fill_color, pydiffvg.LinearGradient):
|
273 |
+
shape_node.set('fill', 'url(#shape_{}_fill)'.format(i))
|
274 |
+
else:
|
275 |
+
c = shape_group.fill_color.data.cpu().numpy()
|
276 |
+
shape_node.set('fill', 'rgb({}, {}, {})'.format(
|
277 |
+
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
|
278 |
+
shape_node.set('opacity', str(c[3]))
|
279 |
+
else:
|
280 |
+
shape_node.set('fill', 'none')
|
281 |
+
if shape_group.stroke_color is not None:
|
282 |
+
if isinstance(shape_group.stroke_color, pydiffvg.LinearGradient):
|
283 |
+
shape_node.set('stroke', 'url(#shape_{}_stroke)'.format(i))
|
284 |
+
else:
|
285 |
+
c = shape_group.stroke_color.data.cpu().numpy()
|
286 |
+
shape_node.set('stroke', 'rgb({}, {}, {})'.format(
|
287 |
+
int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
|
288 |
+
shape_node.set('stroke-opacity', str(c[3]))
|
289 |
+
shape_node.set('stroke-linecap', 'round')
|
290 |
+
shape_node.set('stroke-linejoin', 'round')
|
291 |
+
|
292 |
+
with open(filename, "w") as f:
|
293 |
+
f.write(pydiffvg.prettify(root))
|
294 |
+
|
295 |
+
@staticmethod
|
296 |
+
def save_image(img, filename, gamma=1):
|
297 |
+
if torch.is_tensor(img) and img.device.type != 'cpu':
|
298 |
+
img = img.detach().cpu()
|
299 |
+
pydiffvg.imwrite(img, filename, gamma=gamma)
|
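DiffVGState leaves `clip_curve_shape` abstract and provides generic load/save helpers. The sketch below shows the minimal subclass contract (clamp colors and stroke widths into a valid range before rendering); it only touches attributes defined above, assumes solid RGBA colors rather than gradients, and requires pydiffvg to be installed.

import torch

from svgdreamer.diffvg_warp import DiffVGState


class SimplePainter(DiffVGState):
    """Illustrative subclass: keep fill colors in [0, 1] and widths non-negative."""

    def clip_curve_shape(self):
        with torch.no_grad():
            for group in self.shape_groups:
                if group.fill_color is not None:      # assumes solid colors, not gradients
                    group.fill_color.data.clamp_(0.0, 1.0)
            for shape in self.shapes:
                shape.stroke_width.data.clamp_(min=0.0)


painter = SimplePainter(torch.device("cpu"), use_gpu=False, canvas_width=600, canvas_height=600)
# An existing SVG could then be loaded into the state, e.g.:
# w, h, painter.shapes, painter.shape_groups = DiffVGState.load_svg("assets/Icon-SydneyOperaHouse/p_0.svg")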
svgdreamer/libs/__init__.py
ADDED
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: a self-consistent system,
# including runner, trainer, loss function, EMA, optimizer, lr scheduler, and common utils.

from .model_state import ModelState
from .optim import get_optimizer
svgdreamer/libs/logging.py
ADDED
@@ -0,0 +1,65 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) XiMing Xing. All rights reserved.
|
3 |
+
# Author: XiMing Xing
|
4 |
+
# Description:
|
5 |
+
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import errno
|
9 |
+
|
10 |
+
|
11 |
+
def get_logger(logs_dir: str, file_name: str = "log.txt"):
|
12 |
+
logger = PrintLogger(os.path.join(logs_dir, file_name))
|
13 |
+
sys.stdout = logger # record all python print
|
14 |
+
return logger
|
15 |
+
|
16 |
+
|
17 |
+
class PrintLogger(object):
|
18 |
+
|
19 |
+
def __init__(self, fpath=None):
|
20 |
+
"""
|
21 |
+
python standard input/output records
|
22 |
+
"""
|
23 |
+
self.console = sys.stdout
|
24 |
+
self.file = None
|
25 |
+
if fpath is not None:
|
26 |
+
mkdir_if_missing(os.path.dirname(fpath))
|
27 |
+
self.file = open(fpath, 'w')
|
28 |
+
|
29 |
+
def __del__(self):
|
30 |
+
self.close()
|
31 |
+
|
32 |
+
def __enter__(self):
|
33 |
+
pass
|
34 |
+
|
35 |
+
def __exit__(self, *args):
|
36 |
+
self.close()
|
37 |
+
|
38 |
+
def write(self, msg):
|
39 |
+
self.console.write(msg)
|
40 |
+
if self.file is not None:
|
41 |
+
self.file.write(msg)
|
42 |
+
|
43 |
+
def write_in(self, msg):
|
44 |
+
"""write in log only, not console"""
|
45 |
+
if self.file is not None:
|
46 |
+
self.file.write(msg)
|
47 |
+
|
48 |
+
def flush(self):
|
49 |
+
self.console.flush()
|
50 |
+
if self.file is not None:
|
51 |
+
self.file.flush()
|
52 |
+
os.fsync(self.file.fileno())
|
53 |
+
|
54 |
+
def close(self):
|
55 |
+
self.console.close()
|
56 |
+
if self.file is not None:
|
57 |
+
self.file.close()
|
58 |
+
|
59 |
+
|
60 |
+
def mkdir_if_missing(dir_path):
|
61 |
+
try:
|
62 |
+
os.makedirs(dir_path)
|
63 |
+
except OSError as e:
|
64 |
+
if e.errno != errno.EEXIST:
|
65 |
+
raise
|
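PrintLogger mirrors everything written to stdout into a log file; a short usage sketch follows (the directory name is a placeholder).

from svgdreamer.libs.logging import get_logger

logger = get_logger(logs_dir="./workdir/demo", file_name="log.txt")  # also redirects sys.stdout
print("this line goes to the console and to ./workdir/demo/log.txt")
logger.write_in("this line goes to the log file only\n")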
svgdreamer/libs/model_state.py
ADDED
@@ -0,0 +1,253 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) XiMing Xing. All rights reserved.
|
3 |
+
# Author: XiMing Xing
|
4 |
+
# Description:
|
5 |
+
|
6 |
+
from typing import Union, List
|
7 |
+
from pathlib import Path
|
8 |
+
from datetime import datetime
|
9 |
+
import logging
|
10 |
+
|
11 |
+
from omegaconf import OmegaConf, DictConfig
|
12 |
+
from pprint import pprint
|
13 |
+
import torch
|
14 |
+
from accelerate import Accelerator
|
15 |
+
|
16 |
+
from .logging import get_logger
|
17 |
+
|
18 |
+
|
19 |
+
class ModelState:
|
20 |
+
"""
|
21 |
+
Handles logging and Hugging Face `accelerate` training
|
22 |
+
|
23 |
+
features:
|
24 |
+
- Precision
|
25 |
+
- Device
|
26 |
+
- Optimizer
|
27 |
+
- Logger (default: python system print and logging)
|
28 |
+
- Monitor (default: wandb, tensorboard)
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
args: DictConfig,
|
34 |
+
log_path_suffix: str = None,
|
35 |
+
ignore_log=False, # whether to create log file or not
|
36 |
+
) -> None:
|
37 |
+
self.args: DictConfig = args
|
38 |
+
# set cfg
|
39 |
+
self.state_cfg = args.state
|
40 |
+
self.x_cfg = args.x
|
41 |
+
|
42 |
+
"""check valid"""
|
43 |
+
mixed_precision = self.state_cfg.get("mprec")
|
44 |
+
# Bug: omegaconf converts 'no' to False
|
45 |
+
mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision
|
46 |
+
|
47 |
+
"""create working space"""
|
48 |
+
# rule: ['./config', 'method_name', 'exp_name.yaml']
|
49 |
+
# -> result_path: ./runs/{method_name}-{exp_name}, as a base folder
|
50 |
+
now_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
|
51 |
+
results_folder = self.args.get("result_path", None)
|
52 |
+
if results_folder is None:
|
53 |
+
self.result_path = Path("./workdir") / f"SVGDreamer-{now_time}"
|
54 |
+
else:
|
55 |
+
self.result_path = Path(results_folder) / f"SVGDreamer-{now_time}"
|
56 |
+
|
57 |
+
# update result_path: ./runs/{method_name}-{exp_name}/{log_path_suffix}
|
58 |
+
# note: can be read as "results dir / method / ablation study / your result"
|
59 |
+
if log_path_suffix is not None:
|
60 |
+
self.result_path = self.result_path / f"{log_path_suffix}"
|
61 |
+
else:
|
62 |
+
self.result_path = self.result_path / f"SVGDreamer"
|
63 |
+
|
64 |
+
"""init visualized tracker"""
|
65 |
+
# TODO: monitor with WANDB or TENSORBOARD
|
66 |
+
self.log_with = []
|
67 |
+
# if self.state_cfg.wandb:
|
68 |
+
# self.log_with.append(LoggerType.WANDB)
|
69 |
+
# if self.state_cfg.tensorboard:
|
70 |
+
# self.log_with.append(LoggerType.TENSORBOARD)
|
71 |
+
|
72 |
+
"""HuggingFace Accelerator"""
|
73 |
+
self.accelerator = Accelerator(
|
74 |
+
device_placement=True,
|
75 |
+
mixed_precision=mixed_precision,
|
76 |
+
cpu=True if self.state_cfg.cpu else False,
|
77 |
+
log_with=None if len(self.log_with) == 0 else self.log_with,
|
78 |
+
project_dir=self.result_path / "vis",
|
79 |
+
)
|
80 |
+
|
81 |
+
"""logs"""
|
82 |
+
if self.accelerator.is_local_main_process:
|
83 |
+
# logging
|
84 |
+
self.log = logging.getLogger(__name__)
|
85 |
+
|
86 |
+
# log results in a folder periodically
|
87 |
+
self.result_path.mkdir(parents=True, exist_ok=True)
|
88 |
+
if not ignore_log:
|
89 |
+
self.logger = get_logger(
|
90 |
+
logs_dir=self.result_path.as_posix(),
|
91 |
+
file_name=f"{now_time}-{args.seed}-log.txt"
|
92 |
+
)
|
93 |
+
|
94 |
+
print("==> system args: ")
|
95 |
+
sys_cfg = OmegaConf.masked_copy(args, ["x"])
|
96 |
+
print(sys_cfg)
|
97 |
+
print("==> yaml config args: ")
|
98 |
+
print(self.x_cfg)
|
99 |
+
|
100 |
+
print("\n***** Model State *****")
|
101 |
+
print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp}")
|
102 |
+
print(f"-> Weight dtype: {self.weight_dtype}")
|
103 |
+
|
104 |
+
if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled:
|
105 |
+
print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}")
|
106 |
+
|
107 |
+
print(f"-> Working Space: '{self.result_path}'")
|
108 |
+
|
109 |
+
"""glob step"""
|
110 |
+
self.step = 0
|
111 |
+
|
112 |
+
"""log process"""
|
113 |
+
self.accelerator.wait_for_everyone()
|
114 |
+
print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}')
|
115 |
+
|
116 |
+
self.print("-> state initialization complete \n")
|
117 |
+
|
118 |
+
@property
|
119 |
+
def device(self):
|
120 |
+
return self.accelerator.device
|
121 |
+
|
122 |
+
@property
|
123 |
+
def is_main_process(self):
|
124 |
+
return self.accelerator.is_main_process
|
125 |
+
|
126 |
+
@property
|
127 |
+
def weight_dtype(self):
|
128 |
+
weight_dtype = torch.float32
|
129 |
+
if self.accelerator.mixed_precision == "fp16":
|
130 |
+
weight_dtype = torch.float16
|
131 |
+
elif self.accelerator.mixed_precision == "bf16":
|
132 |
+
weight_dtype = torch.bfloat16
|
133 |
+
return weight_dtype
|
134 |
+
|
135 |
+
@property
|
136 |
+
def n_gpus(self):
|
137 |
+
return self.accelerator.num_processes
|
138 |
+
|
139 |
+
@property
|
140 |
+
def no_decay_params_names(self):
|
141 |
+
no_decay = [
|
142 |
+
"bn", "LayerNorm", "GroupNorm",
|
143 |
+
]
|
144 |
+
return no_decay
|
145 |
+
|
146 |
+
def no_decay_params(self, model, weight_decay):
|
147 |
+
"""optimization tricks"""
|
148 |
+
optimizer_grouped_parameters = [
|
149 |
+
{
|
150 |
+
"params": [
|
151 |
+
p for n, p in model.named_parameters()
|
152 |
+
if not any(nd in n for nd in self.no_decay_params_names)
|
153 |
+
],
|
154 |
+
"weight_decay": weight_decay,
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"params": [
|
158 |
+
p for n, p in model.named_parameters()
|
159 |
+
if any(nd in n for nd in self.no_decay_params_names)
|
160 |
+
],
|
161 |
+
"weight_decay": 0.0,
|
162 |
+
},
|
163 |
+
]
|
164 |
+
return optimizer_grouped_parameters
|
165 |
+
|
166 |
+
def optimized_params(self, model: torch.nn.Module, verbose=True) -> List:
|
167 |
+
"""return parameters if `requires_grad` is True
|
168 |
+
|
169 |
+
Args:
|
170 |
+
model: pytorch models
|
171 |
+
verbose: log optimized parameters
|
172 |
+
|
173 |
+
Examples:
|
174 |
+
>>> params_optimized = self.optimized_params(uvit, verbose=True)
|
175 |
+
>>> optimizer = torch.optim.AdamW(params_optimized, lr=1e-3)
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
a list of parameters
|
179 |
+
"""
|
180 |
+
params_optimized = []
|
181 |
+
for key, value in model.named_parameters():
|
182 |
+
if value.requires_grad:
|
183 |
+
params_optimized.append(value)
|
184 |
+
if verbose:
|
185 |
+
self.print("\t {}, {}, {}".format(key, value.numel(), value.shape))
|
186 |
+
return params_optimized
|
187 |
+
|
188 |
+
def save_everything(self, fpath: str):
|
189 |
+
"""Saving and loading the model, optimizer, RNG generators, and the GradScaler."""
|
190 |
+
if not self.accelerator.is_main_process:
|
191 |
+
return
|
192 |
+
self.accelerator.save_state(fpath)
|
193 |
+
|
194 |
+
def load_save_everything(self, fpath: str):
|
195 |
+
"""Loading the model, optimizer, RNG generators, and the GradScaler."""
|
196 |
+
self.accelerator.load_state(fpath)
|
197 |
+
|
198 |
+
def save(self, milestone: Union[str, float, int], checkpoint: object) -> None:
|
199 |
+
if not self.accelerator.is_main_process:
|
200 |
+
return
|
201 |
+
|
202 |
+
torch.save(checkpoint, self.result_path / f'model-{milestone}.pt')
|
203 |
+
|
204 |
+
def save_in(self, root: Union[str, Path], checkpoint: object) -> None:
|
205 |
+
if not self.accelerator.is_main_process:
|
206 |
+
return
|
207 |
+
|
208 |
+
torch.save(checkpoint, root)
|
209 |
+
|
210 |
+
def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False):
|
211 |
+
ckpt = torch.load(path, map_location=self.device)
|
212 |
+
|
213 |
+
unwrapped_model = self.accelerator.unwrap_model(model)
|
214 |
+
if rm_module_prefix:
|
215 |
+
unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
|
216 |
+
else:
|
217 |
+
unwrapped_model.load_state_dict(ckpt)
|
218 |
+
return unwrapped_model
|
219 |
+
|
220 |
+
def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]):
|
221 |
+
ckpt = torch.load(path, map_location=self.accelerator.device)
|
222 |
+
self.print(f"pretrained_dict len: {len(ckpt)}")
|
223 |
+
unwrapped_model = self.accelerator.unwrap_model(model)
|
224 |
+
model_dict = unwrapped_model.state_dict()
|
225 |
+
pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
|
226 |
+
model_dict.update(pretrained_dict)
|
227 |
+
unwrapped_model.load_state_dict(model_dict, strict=False)
|
228 |
+
self.print(f"selected pretrained_dict: {len(model_dict)}")
|
229 |
+
return unwrapped_model
|
230 |
+
|
231 |
+
def print(self, *args, **kwargs):
|
232 |
+
"""Use in replacement of `print()` to only print once per server."""
|
233 |
+
self.accelerator.print(*args, **kwargs)
|
234 |
+
|
235 |
+
def pretty_print(self, msg):
|
236 |
+
if self.accelerator.is_main_process:
|
237 |
+
pprint(dict(msg))
|
238 |
+
|
239 |
+
def close_tracker(self):
|
240 |
+
self.accelerator.end_training()
|
241 |
+
|
242 |
+
def free_memory(self):
|
243 |
+
self.accelerator.clear()
|
244 |
+
|
245 |
+
def close(self, msg: str = "Training complete."):
|
246 |
+
"""Use in end of training."""
|
247 |
+
self.free_memory()
|
248 |
+
|
249 |
+
if torch.cuda.is_available():
|
250 |
+
self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
|
251 |
+
if len(self.log_with) > 0:
|
252 |
+
self.close_tracker()
|
253 |
+
self.print(msg)
|
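For orientation, here is a minimal sketch of how `ModelState` might be driven. The config keys (`seed`, `result_path`, `state.cpu`, `state.mprec`, and the `x` group) mirror what the class reads above, but the concrete values and the `"demo"` suffix are illustrative assumptions, not taken from `conf/config.yaml`.

```python
# Hypothetical usage sketch; the config values below are assumptions for illustration.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "seed": 100,
    "result_path": "./workdir",                 # base folder for run outputs
    "state": {"cpu": False, "mprec": "no"},     # "fp16"/"bf16" would enable mixed precision
    "x": {"style": "iconography"},              # method-specific config group
})

state = ModelState(cfg, log_path_suffix="demo")  # creates ./workdir/SVGDreamer-<time>/demo
state.print(f"device: {state.device}, weight dtype: {state.weight_dtype}")
state.close()
```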
svgdreamer/libs/optim.py ADDED @@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# Author: ximing
# Description: optimizers
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
from functools import partial

import torch
from omegaconf import DictConfig


def get_optimizer(optimizer_name, parameters, lr=None, config: DictConfig = None):
    param_dict = {}
    if optimizer_name == "adam":
        optimizer = partial(torch.optim.Adam, params=parameters)
        if lr is not None:
            optimizer = partial(torch.optim.Adam, params=parameters, lr=lr)
        if config.get('betas'):
            param_dict['betas'] = config.betas
        if config.get('weight_decay'):
            param_dict['weight_decay'] = config.weight_decay
        if config.get('eps'):
            param_dict['eps'] = config.eps
    elif optimizer_name == "adamW":
        optimizer = partial(torch.optim.AdamW, params=parameters)
        if lr is not None:
            optimizer = partial(torch.optim.AdamW, params=parameters, lr=lr)
        if config.get('betas'):
            param_dict['betas'] = config.betas
        if config.get('weight_decay'):
            param_dict['weight_decay'] = config.weight_decay
        if config.get('eps'):
            param_dict['eps'] = config.eps
    elif optimizer_name == "radam":
        optimizer = partial(torch.optim.RAdam, params=parameters)
        if lr is not None:
            optimizer = partial(torch.optim.RAdam, params=parameters, lr=lr)
        if config.get('betas'):
            param_dict['betas'] = config.betas
        if config.get('weight_decay'):
            param_dict['weight_decay'] = config.weight_decay
    elif optimizer_name == "sgd":
        optimizer = partial(torch.optim.SGD, params=parameters)
        if lr is not None:
            optimizer = partial(torch.optim.SGD, params=parameters, lr=lr)
        if config.get('momentum'):
            param_dict['momentum'] = config.momentum
        if config.get('weight_decay'):
            param_dict['weight_decay'] = config.weight_decay
        if config.get('nesterov'):
            param_dict['nesterov'] = config.nesterov
    else:
        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")

    if len(param_dict.keys()) > 0:
        return optimizer(**param_dict)
    else:
        return optimizer()
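A small usage sketch of `get_optimizer` follows; the hyper-parameter values are illustrative assumptions rather than values from the shipped YAML configs.

```python
# Hypothetical usage sketch; the hyper-parameter values are assumptions.
import torch
from omegaconf import OmegaConf

model = torch.nn.Linear(8, 2)
optim_cfg = OmegaConf.create({"betas": [0.9, 0.999], "weight_decay": 0.1, "eps": 1e-8})
optimizer = get_optimizer("adamW", model.parameters(), lr=1e-3, config=optim_cfg)
```

Note that the `config.get(...)` checks treat falsy values as absent, so a zero-valued entry such as `weight_decay: 0.0` is skipped and PyTorch's default is used instead.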
svgdreamer/painter/VPSD_pipeline.py ADDED @@ -0,0 +1,585 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import re
import PIL
from PIL import Image
from typing import Any, List, Optional, Union, Dict
from omegaconf import DictConfig

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from diffusers import DDIMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    rescale_noise_cfg, StableDiffusionPipelineOutput)
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.loaders import AttnProcsLayers

from svgdreamer.diffusers_warp import init_StableDiffusion_pipeline, init_diffusers_unet


class VectorizedParticleSDSPipeline(torch.nn.Module):

    def __init__(self, model_cfg: DictConfig, diffuser_cfg: DictConfig, guidance_cfg: DictConfig, device: torch.device):
        super().__init__()
        self.device = device
        assert guidance_cfg.n_particle >= guidance_cfg.vsd_n_particle
        assert guidance_cfg.n_particle >= guidance_cfg.phi_n_particle

        pipe_kwargs = {
            "device": self.device,
            "torch_dtype": torch.float32,
            "local_files_only": not diffuser_cfg.download,
            "force_download": diffuser_cfg.force_download,
            "resume_download": diffuser_cfg.resume_download,
            "ldm_speed_up": model_cfg.ldm_speed_up,
            "enable_xformers": model_cfg.enable_xformers,
            "gradient_checkpoint": model_cfg.gradient_checkpoint,
            "cpu_offload": model_cfg.cpu_offload,
            "vae_slicing": False
        }

        # load pretrained model
        self.sd_pipeline = init_StableDiffusion_pipeline(
            model_cfg.model_id,
            custom_pipeline=StableDiffusionPipeline,
            custom_scheduler=DDIMScheduler,
            **pipe_kwargs
        )
        # disable grads
        self.sd_pipeline.vae.requires_grad_(False)
        self.sd_pipeline.text_encoder.requires_grad_(False)
        self.sd_pipeline.unet.requires_grad_(False)
        # set components
        self.vae = self.sd_pipeline.vae
        self.unet = self.sd_pipeline.unet
        self.scheduler = self.sd_pipeline.scheduler
        self.tokenizer = self.sd_pipeline.tokenizer
        self.text_encoder = self.sd_pipeline.text_encoder

        if guidance_cfg.phi_model == 'lora':
            if guidance_cfg.phi_single:  # default, use the single unet
                # load LoRA model from the pretrained model
                unet_ = self.unet
            else:
                # create a new unet model
                pipe_kwargs.pop('cpu_offload')
                pipe_kwargs.pop('vae_slicing')
                unet_ = init_diffusers_unet(model_cfg.model_id, **pipe_kwargs)

            # set correct LoRA layers
            self.unet_phi, phi_model_layers = self.set_lora_layers(unet_)
            self.phi_params = list(phi_model_layers.parameters())
            self.lora_cross_attention_kwargs = {"scale": guidance_cfg.lora_attn_scale} \
                if guidance_cfg.use_attn_scale else {}
            self.vae_phi = self.vae
            self.vae_phi.requires_grad_(False)

        elif guidance_cfg.phi_model == 'unet_simple':
            self.unet_phi = UNet2DConditionModel(
                sample_size=64,
                in_channels=4,
                out_channels=4,
                layers_per_block=1,
                block_out_channels=(128, 256, 384, 512),
                down_block_types=(
                    "DownBlock2D",
                    "AttnDownBlock2D",
                    "AttnDownBlock2D",
                    "AttnDownBlock2D",
                ),
                up_block_types=(
                    "AttnUpBlock2D",
                    "AttnUpBlock2D",
                    "AttnUpBlock2D",
                    "UpBlock2D",
                ),
                cross_attention_dim=self.unet.config.cross_attention_dim
            ).to(device)
            self.phi_params = list(self.unet_phi.parameters())
            self.vae_phi = self.vae
            # reset lora
            guidance_cfg.use_attn_scale = False
            guidance_cfg.lora_attn_scale = False

        # hyper-params
        self.phi_single = guidance_cfg.phi_single
        self.guidance_scale: float = guidance_cfg.guidance_scale
        self.guidance_scale_lora: float = guidance_cfg.phi_guidance_scale
        self.grad_clip_val: Union[float, None] = guidance_cfg.grad_clip_val
        self.vsd_n_particle: int = guidance_cfg.vsd_n_particle
        self.phi_n_particle: int = guidance_cfg.phi_n_particle
        self.t_schedule: str = guidance_cfg.t_schedule
        self.t_range = list(guidance_cfg.t_range)
        print(
            f'n_particles: {guidance_cfg.n_particle}, '
            f'enhance_particles: {guidance_cfg.particle_aug}, '
            f'n_particles of score: {self.vsd_n_particle}, '
            f'n_particles of phi_model: {self.phi_n_particle}, \n'
            f't_range: {self.t_range}, '
            f't_schedule: {self.t_schedule}, \n'
            f'guidance_scale: {self.guidance_scale}, phi_guidance_scale: {self.guidance_scale_lora}.'
        )
        print(f"phi_model: {guidance_cfg.phi_model}, "
              f"use lora_cross_attn: {guidance_cfg.use_attn_scale}, "
              f"lora_attn_scale: {guidance_cfg.lora_attn_scale}. \n")

        # for convenience
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)
        self.text_embeddings = None
        self.text_embedd_cond, self.text_embedd_uncond = None, None
        self.text_embeddings_phi = None
        self.t = None

    def set_lora_layers(self, unet):  # set correct lora layers
        lora_attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") \
                else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim
            ).to(self.device)
        unet.set_attn_processor(lora_attn_procs)
        lora_layers = AttnProcsLayers(unet.attn_processors)

        unet.requires_grad_(False)
        for param in lora_layers.parameters():
            param.requires_grad_(True)
        return unet, lora_layers

    @torch.no_grad()
    def encode_prompt(self,
                      prompt,
                      device,
                      do_classifier_free_guidance,
                      negative_prompt=None):
        # text conditional embed
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        prompt_embeds = self.text_encoder(text_inputs.input_ids.to(device))[0]

        if do_classifier_free_guidance:
            if negative_prompt is None:
                uncond_tokens = [""]
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            else:
                uncond_tokens = negative_prompt

            # unconditional embed
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=prompt_embeds.shape[1],
                truncation=True,
                return_tensors="pt",
            )
            negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device))[0]

            concat_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            return concat_prompt_embeds, negative_prompt_embeds, prompt_embeds

        return prompt_embeds, None, None

    def sampling(self,
                 vae,
                 unet,
                 scheduler,
                 prompt: Union[str, List[str]] = None,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 num_inference_steps: int = 50,
                 guidance_scale: float = 7.5,
                 negative_prompt: Optional[Union[str, List[str]]] = None,
                 num_images_per_prompt: Optional[int] = 1,
                 eta: float = 0.0,
                 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
                 latents: Optional[torch.FloatTensor] = None,
                 output_type: Optional[str] = "pil",
                 return_dict: bool = True,
                 cross_attention_kwargs: Optional[Dict[str, Any]] = None,
                 guidance_rescale: float = 0.0):

        # 0. Default height and width to unet
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
        height = height or unet.config.sample_size * vae_scale_factor
        width = width or unet.config.sample_size * vae_scale_factor

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = 1

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, _, _ = self.encode_prompt(
            prompt,
            self.device,
            do_classifier_free_guidance,
            negative_prompt,
        )

        # 4. Prepare timesteps
        scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = unet.config.in_channels
        latents = self.sd_pipeline.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            self.device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.sd_pipeline.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.sd_pipeline.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # update progress_bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.sd_pipeline.run_safety_checker(image, self.device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.sd_pipeline.image_processor.postprocess(image, output_type=output_type,
                                                             do_denormalize=do_denormalize)
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def sample(self,
               prompt,
               height: Optional[int] = None,
               width: Optional[int] = None,
               num_inference_steps: int = 50,
               guidance_scale: float = 7.5,
               negative_prompt: Optional[Union[str, List[str]]] = None,
               generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
               output_type: Optional[str] = "pil"):
        return self.sampling(self.vae, self.unet, self.scheduler,
                             prompt=prompt,
                             height=height, width=width,
                             num_inference_steps=num_inference_steps,
                             guidance_scale=guidance_scale,
                             negative_prompt=negative_prompt,
                             generator=generator,
                             output_type=output_type)

    def sample_lora(self,
                    prompt,
                    height: Optional[int] = None,
                    width: Optional[int] = None,
                    num_inference_steps: int = 50,
                    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
                    output_type: Optional[str] = "pil"):
        return self.sampling(self.vae_phi, self.unet_phi, self.scheduler,
                             prompt=prompt,
                             height=height, width=width,
                             num_inference_steps=num_inference_steps,
                             guidance_scale=self.guidance_scale_lora,
                             generator=generator,
                             cross_attention_kwargs=self.lora_cross_attention_kwargs,
                             output_type=output_type)

    def encode2latent(self, images):
        images = (2 * images - 1).clamp(-1.0, 1.0)  # images: [B, 3, H, W]
        # encode images
        latents = self.vae.encode(images).latent_dist.sample()
        latents = self.vae.config.scaling_factor * latents
        return latents

    def get_noise_map(self, noise_pred, guidance_scale=7.5, use_cfg=True):
        if use_cfg:
            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
            noise_map = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
            return noise_map
        else:
            return noise_pred

    def train_phi_model(self,
                        pred_rgb: torch.Tensor,
                        new_timesteps: bool = False,
                        as_latent: bool = False):
        # interp to 512x512 to be fed into vae.
        if as_latent:
            latents = pred_rgb
        else:
            pred_rgb_ = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode2latent(pred_rgb_)

        # get phi particles
        indices = torch.randperm(latents.size(0))
        latents_phi = latents[indices[:self.phi_n_particle]]
        latents_phi = latents_phi.detach()

        # get timestep
        if new_timesteps:
            t = torch.randint(0, self.num_train_timesteps, (1,), device=self.device)
        else:
            t = self.t

        noise = torch.randn_like(latents_phi)
        noisy_latents = self.scheduler.add_noise(latents_phi, noise, t)

        if self.scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.scheduler.config.prediction_type == "v_prediction":
            target = self.scheduler.get_velocity(latents_phi, noise, t)
        else:
            raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")

        # predict the noise residual and compute loss
        noise_pred = self.unet_phi(
            noisy_latents, t,
            encoder_hidden_states=self.text_embeddings_phi,
            cross_attention_kwargs=self.lora_cross_attention_kwargs,
        ).sample

        return F.mse_loss(noise_pred, target, reduction="mean")

    def train_phi_model_refl(self,
                             pred_rgb: torch.Tensor,
                             weight: float = 1,
                             new_timesteps: bool = True):
        # interp to 512x512 to be fed into vae.
        pred_rgb_ = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        # encode image into latents with vae, requires grad!
        latents = self.encode2latent(pred_rgb_)

        # get phi particles
        indices = torch.randperm(latents.size(0))
        latents_phi = latents[indices[:self.phi_n_particle]]
        latents_phi = latents_phi.detach()

        # get timestep
        if new_timesteps:
            t = torch.randint(0, self.num_train_timesteps, (1,), device=self.device)
        else:
            t = self.t

        noise = torch.randn_like(latents_phi)
        noisy_latents = self.scheduler.add_noise(latents_phi, noise, t)

        if self.scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.scheduler.config.prediction_type == "v_prediction":
            target = self.scheduler.get_velocity(latents_phi, noise, t)
        else:
            raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")

        # predict the noise residual and compute loss
        noise_pred = self.unet_phi(
            noisy_latents, t,
            encoder_hidden_states=self.text_embedd_cond,
            cross_attention_kwargs=self.lora_cross_attention_kwargs,
        ).sample

        rewards = torch.tensor(weight, dtype=torch.float32, device=self.device)
        return rewards * F.mse_loss(noise_pred, target, reduction="mean")

    def schedule_timestep(self, step):
        min_step = int(self.num_train_timesteps * self.t_range[0])
        max_step = int(self.num_train_timesteps * self.t_range[1])
        if self.t_schedule == 'randint':
            t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
        elif re.match(r"max_([\d.]+)_(\d+)", self.t_schedule):
            # Anneal time schedule
            # e.g: t_schedule == 'max_0.5_200'
            # [0.02, 0.98] -> [0.02, 0.5] after 200 steps
            tag, t_val, step_upd = str(self.t_schedule).split('_')
            t_val, step_upd = float(t_val), int(step_upd)
            if step >= step_upd:
                max_step = int(self.num_train_timesteps * t_val)
            t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
        elif re.match(r"min_([\d.]+)_(\d+)", self.t_schedule):
            # Anneal time schedule
            # e.g: t_schedule == 'min_0.5_200'
            # [0.02, 0.98] -> [0.5, 0.98] after 200 steps
            tag, t_val, step_upd = str(self.t_schedule).split('_')
            t_val, step_upd = float(t_val), int(step_upd)
            if step >= step_upd:
                min_step = int(self.num_train_timesteps * t_val)
            t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
        else:
            raise NotImplementedError(f"{self.t_schedule} is not supported.")
        return t

    def set_text_embeddings(self, prompt, negative_prompt, do_classifier_free_guidance):
        if self.text_embeddings is not None:
            return

        # encode text prompt
        text_embeddings, text_embeddings_uncond, text_embeddings_cond = \
            self.encode_prompt(prompt, self.device, do_classifier_free_guidance, negative_prompt=negative_prompt)

        # set pretrained model text embedding
        text_embeddings_uncond, text_embeddings_cond = text_embeddings.chunk(2)
        self.text_embedd_uncond, self.text_embedd_cond = text_embeddings_uncond, text_embeddings_cond
        text_embeddings_unconds = text_embeddings_uncond.repeat_interleave(self.vsd_n_particle, dim=0)
        text_embeddings_conds = text_embeddings_cond.repeat_interleave(self.vsd_n_particle, dim=0)
        text_embeddings = torch.cat([text_embeddings_unconds, text_embeddings_conds])
        self.text_embeddings = text_embeddings

        # set phi model text embedding
        self.text_embeddings_phi = text_embeddings_cond.repeat_interleave(self.phi_n_particle, dim=0)

    def x_augment(self, x: torch.Tensor, img_size: int = 512):
        augment_compose = transforms.Compose([
            transforms.RandomPerspective(distortion_scale=0.5, p=0.7),
            transforms.RandomCrop(size=(img_size, img_size), pad_if_needed=True, padding_mode='reflect')
        ])
        return augment_compose(x)

    def variational_score_distillation(self,
                                       pred_rgb: torch.Tensor,
                                       step: int,
                                       prompt: Union[List, str],
                                       negative_prompt: Union[List, str] = None,
                                       grad_scale: float = 1.0,
                                       enhance_particle: bool = False,
                                       im_size: int = 512,
                                       as_latent: bool = False):
        bz = pred_rgb.shape[0]

        # data enhancement for the input particles
        pred_rgb = self.x_augment(pred_rgb, im_size) if enhance_particle else pred_rgb

        # interp to 512x512 to be fed into vae.
        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_ = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
            # encode image into latents with vae, requires grad!
            # latents = self.encode2latent(pred_rgb_)
            latent_list = [self.encode2latent(pred_rgb_[i].unsqueeze(0)) for i in range(bz)]
            latents = torch.cat(latent_list, dim=0)
            latents = latents.to(self.device)

        # random sample n_particle_vsd particles from latents
        latents_vsd = latents[torch.randperm(bz)[:self.vsd_n_particle]]

        # encode input prompt
        do_classifier_free_guidance = True
        self.set_text_embeddings(prompt, negative_prompt, do_classifier_free_guidance)
        text_embeddings = self.text_embeddings

        # timestep a.k.a noise level
        self.t = self.schedule_timestep(step)

        # predict the noise residual with unet, stop gradient
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents_vsd)
            latents_noisy = self.scheduler.add_noise(latents_vsd, noise, self.t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2) if do_classifier_free_guidance else latents_noisy
            # pretrained noise prediction network
            noise_pred_pretrain = self.unet(
                latent_model_input, self.t,
                encoder_hidden_states=text_embeddings,
                cross_attention_kwargs={'scale': 0.0} if self.phi_single else {}
            ).sample

            # use conditional text embeddings in phi_model
            _, text_embeddings_cond = text_embeddings.chunk(2)
            # estimated noise prediction network
            noise_pred_est = self.unet_phi(
                latents_noisy, self.t,
                encoder_hidden_states=text_embeddings_cond,
                cross_attention_kwargs=self.lora_cross_attention_kwargs
            ).sample

        # get pretrained score
        noise_pred_pretrain = self.get_noise_map(noise_pred_pretrain, self.guidance_scale, use_cfg=True)
        # get estimated score
        noise_pred_est = self.get_noise_map(noise_pred_est, self.guidance_scale_lora, use_cfg=False)

        # w(t), sigma_t^2
        w = (1 - self.alphas[self.t])
        grad = grad_scale * w * (noise_pred_pretrain - noise_pred_est.detach())
        grad = torch.nan_to_num(grad)

        # grad clipping for stable training
        if self.grad_clip_val is not None and self.grad_clip_val > 0:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # re-parameterization trick:
        # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
        target = (latents_vsd - grad).detach()
        loss_vpsd = 0.5 * F.mse_loss(latents_vsd, target, reduction="sum")

        return loss_vpsd, grad.norm(), latents, self.t
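For context, a rough sketch of how the VPSD objective above might sit inside a particle-optimization loop. `painter` stands in for the repository's differentiable SVG renderer and `pipeline` for a constructed `VectorizedParticleSDSPipeline`; both, the prompt, the step count, and the learning rates are assumptions for illustration only, not the repository's actual entry point.

```python
# Hypothetical training-loop sketch; `pipeline` and `painter` are assumed to be built elsewhere.
import torch

svg_optimizer = torch.optim.Adam(painter.get_point_params(), lr=1.0)  # SVG control points (assumed API)
phi_optimizer = torch.optim.AdamW(pipeline.phi_params, lr=1e-4)       # LoRA / phi parameters

for step in range(500):
    raster_batch = painter.get_image()  # assumed to return [n_particle, 3, H, W] rasterized particles

    # (1) update the SVG particles with the VPSD loss
    loss_vpsd, grad_norm, latents, t = pipeline.variational_score_distillation(
        raster_batch, step, prompt="A minimal flat icon of the Sydney Opera House")
    svg_optimizer.zero_grad()
    loss_vpsd.backward()
    svg_optimizer.step()

    # (2) update the phi estimator (LoRA) on the detached particles
    loss_phi = pipeline.train_phi_model(raster_batch.detach())
    phi_optimizer.zero_grad()
    loss_phi.backward()
    phi_optimizer.step()
```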
svgdreamer/painter/__init__.py ADDED @@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Description:

from .painter_params import (
    Painter, PainterOptimizer, CosineWithWarmupLRLambda, RandomCoordInit, NaiveCoordInit, SparseCoordInit, get_sdf)
from .component_painter_params import CompPainter, CompPainterOptimizer
from .loss import xing_loss_fn
from .VPSD_pipeline import VectorizedParticleSDSPipeline
from .diffusion_pipeline import DiffusionPipeline
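For reference, downstream code can then import these re-exported symbols from the package root, e.g. (hypothetical usage, not a file in this commit):

```python
from svgdreamer.painter import CompPainter, VectorizedParticleSDSPipeline, xing_loss_fn
```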
svgdreamer/painter/component_painter_params.py ADDED @@ -0,0 +1,610 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Author: ximing
|
3 |
+
# Description: content painter and optimizer
|
4 |
+
# Copyright (c) 2023, XiMing Xing.
|
5 |
+
# License: MIT License
|
6 |
+
|
7 |
+
import copy
|
8 |
+
import math
|
9 |
+
import random
|
10 |
+
import pathlib
|
11 |
+
from typing import Dict, Tuple
|
12 |
+
|
13 |
+
from shapely.geometry.polygon import Polygon
|
14 |
+
from omegaconf import DictConfig
|
15 |
+
import numpy as np
|
16 |
+
import pydiffvg
|
17 |
+
import torch
|
18 |
+
from torch.optim.lr_scheduler import LambdaLR
|
19 |
+
|
20 |
+
from svgdreamer.painter import (SparseCoordInit, RandomCoordInit, NaiveCoordInit, get_sdf)
|
21 |
+
from svgdreamer.libs import get_optimizer
|
22 |
+
|
23 |
+
|
24 |
+
class CompPainter:
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
style: str,
|
29 |
+
target_img: torch.Tensor,
|
30 |
+
canvas_size: Tuple[int, int] = (600, 600),
|
31 |
+
num_segments: int = 4,
|
32 |
+
segment_init: str = 'circle',
|
33 |
+
radius: int = 20,
|
34 |
+
n_grid: int = 32,
|
35 |
+
stroke_width: int = 3,
|
36 |
+
device=None,
|
37 |
+
attn_init: bool = False,
|
38 |
+
attention_map: torch.Tensor = None,
|
39 |
+
attn_prob_tau: float = None,
|
40 |
+
):
|
41 |
+
self.style = style
|
42 |
+
self.device = device
|
43 |
+
self.target_img = target_img
|
44 |
+
|
45 |
+
# curve params
|
46 |
+
self.num_segments = num_segments
|
47 |
+
self.segment_init = segment_init
|
48 |
+
self.radius = radius
|
49 |
+
|
50 |
+
self.canvas_width, self.canvas_height = canvas_size
|
51 |
+
"""pixelart params"""
|
52 |
+
self.n_grid = n_grid # divide the canvas into n grids
|
53 |
+
self.pixel_per_grid = self.canvas_width // self.n_grid
|
54 |
+
"""sketch params"""
|
55 |
+
self.stroke_width = stroke_width
|
56 |
+
"""iconography params"""
|
57 |
+
self.color_ref = None
|
58 |
+
|
59 |
+
self.shapes = [] # record all paths
|
60 |
+
self.shape_groups = []
|
61 |
+
self.cur_shapes, self.cur_shape_groups = [], [] # record the current optimized path
|
62 |
+
self.point_vars = []
|
63 |
+
self.color_vars = []
|
64 |
+
self.width_vars = []
|
65 |
+
|
66 |
+
# init
|
67 |
+
self.attention_map = attention_map
|
68 |
+
self.attn_init = attn_init
|
69 |
+
self.attn_prob_tau = attn_prob_tau
|
70 |
+
self.select_inds = None
|
71 |
+
self.pos_init_method = None
|
72 |
+
|
73 |
+
# background
|
74 |
+
self.para_bg = torch.tensor([1., 1., 1.], requires_grad=False, device=self.device)
|
75 |
+
# count the number of strokes
|
76 |
+
self.strokes_counter = 0 # counts the number of calls to "get_path"
|
77 |
+
|
78 |
+
def attn_init_points(self, num_paths, mask=None):
|
79 |
+
attn_map = (self.attention_map - self.attention_map.min()) / \
|
80 |
+
(self.attention_map.max() - self.attention_map.min())
|
81 |
+
|
82 |
+
attn_map_soft = np.copy(attn_map)
|
83 |
+
attn_map_soft[attn_map > 0] = softmax_t(attn_map[attn_map > 0], tau=self.attn_prob_tau)
|
84 |
+
# for visualizing
|
85 |
+
attn_thresh = np.copy(attn_map_soft)
|
86 |
+
# the probabilities associated with each entry in attn_map
|
87 |
+
attn_map_soft /= np.sum(attn_map_soft)
|
88 |
+
# select points
|
89 |
+
k = num_paths
|
90 |
+
|
91 |
+
# select k points randomly
|
92 |
+
positions = np.where(mask == 1)
|
93 |
+
positions = np.stack(positions, axis=1)
|
94 |
+
np.random.shuffle(positions)
|
95 |
+
positions = positions[:k]
|
96 |
+
|
97 |
+
# note: only use to visual
|
98 |
+
visual_coords = np.copy(positions)
|
99 |
+
|
100 |
+
canvas_coords = np.copy(positions)
|
101 |
+
canvas_coords[:, [0, 1]] = canvas_coords[:, [1, 0]]
|
102 |
+
self.select_inds = canvas_coords
|
103 |
+
|
104 |
+
# for visualizing
|
105 |
+
return attn_thresh, visual_coords
|
106 |
+
|
107 |
+
def component_wise_path_init(self, pred, init_type: str = 'sparse'):
|
108 |
+
if init_type == 'random':
|
109 |
+
self.pos_init_method = RandomCoordInit(self.canvas_height, self.canvas_width)
|
110 |
+
|
111 |
+
elif init_type == 'sparse':
|
112 |
+
assert self.target_img is not None # target_img as GT
|
113 |
+
# when initialized for the first time, the render result is None
|
114 |
+
if pred is None:
|
115 |
+
pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width)
|
116 |
+
# then pred is the render result
|
117 |
+
self.pos_init_method = SparseCoordInit(pred, self.target_img)
|
118 |
+
|
119 |
+
elif init_type == 'naive':
|
120 |
+
assert self.target_img is not None # target_img as GT
|
121 |
+
if pred is None:
|
122 |
+
pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width)
|
123 |
+
self.pos_init_method = NaiveCoordInit(pred, self.target_img)
|
124 |
+
|
125 |
+
else:
|
126 |
+
raise NotImplementedError(f"'{init_type}' is not support.")
|
127 |
+
|
128 |
+
def init_image(self, num_paths=0):
|
129 |
+
self.cur_shapes, self.cur_shape_groups = [], []
|
130 |
+
|
131 |
+
if self.style == 'pixelart': # update path definition
|
132 |
+
num_paths = self.n_grid
|
133 |
+
|
134 |
+
for i in range(num_paths):
|
135 |
+
if self.style == 'iconography':
|
136 |
+
path = self.get_path()
|
137 |
+
self.shapes.append(path)
|
138 |
+
self.cur_shapes.append(path)
|
139 |
+
|
140 |
+
wref, href = self.color_ref
|
141 |
+
wref = max(0, min(int(wref), self.canvas_width - 1))
|
142 |
+
href = max(0, min(int(href), self.canvas_height - 1))
|
143 |
+
fill_color_init = list(self.target_img[0, :, href, wref]) + [1.]
|
144 |
+
fill_color_init = torch.FloatTensor(fill_color_init)
|
145 |
+
path_group = pydiffvg.ShapeGroup(
|
146 |
+
shape_ids=torch.tensor([len(self.shapes) - 1]),
|
147 |
+
fill_color=fill_color_init,
|
148 |
+
stroke_color=None
|
149 |
+
)
|
150 |
+
self.shape_groups.append(path_group)
|
151 |
+
self.cur_shape_groups.append(path_group)
|
152 |
+
|
153 |
+
elif self.style == 'pixelart':
|
154 |
+
fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
|
155 |
+
fill_color_init[-1] = 1.0
|
156 |
+
|
157 |
+
for j in range(num_paths):
|
158 |
+
path = self.get_path(coord=[i, j])
|
159 |
+
self.shapes.append(path)
|
160 |
+
self.cur_shapes.append(path)
|
161 |
+
|
162 |
+
path_group = pydiffvg.ShapeGroup(
|
163 |
+
shape_ids=torch.LongTensor([i * num_paths + j]),
|
164 |
+
fill_color=fill_color_init,
|
165 |
+
stroke_color=None,
|
166 |
+
)
|
167 |
+
self.shape_groups.append(path_group)
|
168 |
+
self.cur_shape_groups.append(path_group)
|
169 |
+
|
170 |
+
elif self.style == 'sketch':
|
171 |
+
path = self.get_path()
|
172 |
+
self.shapes.append(path)
|
173 |
+
self.cur_shapes.append(path)
|
174 |
+
|
175 |
+
stroke_color_init = torch.tensor([0.0, 0.0, 0.0, 1.0])
|
176 |
+
path_group = pydiffvg.ShapeGroup(
|
177 |
+
shape_ids=torch.tensor([len(self.shapes) - 1]),
|
178 |
+
fill_color=None,
|
179 |
+
stroke_color=stroke_color_init
|
180 |
+
)
|
181 |
+
self.shape_groups.append(path_group)
|
182 |
+
self.cur_shape_groups.append(path_group)
|
183 |
+
|
184 |
+
elif self.style == 'painting':
|
185 |
+
path = self.get_path()
|
186 |
+
self.shapes.append(path)
|
187 |
+
self.cur_shapes.append(path)
|
188 |
+
|
189 |
+
wref, href = self.color_ref
|
190 |
+
wref = max(0, min(int(wref), self.canvas_width - 1))
|
191 |
+
href = max(0, min(int(href), self.canvas_height - 1))
|
192 |
+
stroke_color_init = list(self.target_img[0, :, href, wref]) + [1.]
|
193 |
+
path_group = pydiffvg.ShapeGroup(
|
194 |
+
shape_ids=torch.tensor([len(self.shapes) - 1]),
|
195 |
+
fill_color=None,
|
196 |
+
stroke_color=stroke_color_init
|
197 |
+
)
|
198 |
+
self.shape_groups.append(path_group)
|
199 |
+
self.cur_shape_groups.append(path_group)
|
200 |
+
|
201 |
+
img = self.render_warp()
|
202 |
+
img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4])
|
203 |
+
img = img.unsqueeze(0) # convert img from HWC to NCHW
|
204 |
+
img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW
|
205 |
+
return img
|
206 |
+
|
207 |
+
def get_image(self, step: int = 0):
|
208 |
+
img = self.render_warp(step)
|
209 |
+
img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4])
|
210 |
+
img = img.unsqueeze(0) # convert img from HWC to NCHW
|
211 |
+
img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW
|
212 |
+
return img
|
213 |
+
|
214 |
+
def get_path(self, coord=None):
|
215 |
+
num_segments = self.num_segments
|
216 |
+
|
217 |
+
points = []
|
218 |
+
if self.style == 'iconography':
|
219 |
+
num_control_points = [2] * num_segments
|
220 |
+
# init segment
|
221 |
+
if self.segment_init == 'circle':
|
222 |
+
radius = self.radius if self.radius is not None else np.random.uniform(0.5, 1)
|
223 |
+
|
224 |
+
if self.attn_init:
|
225 |
+
center = self.select_inds[self.strokes_counter] # shape: (2,)
|
226 |
+
else:
|
227 |
+
center = (random.random(), random.random()) \
|
228 |
+
if self.pos_init_method is None else self.pos_init_method()
|
229 |
+
|
230 |
+
bias = center
|
231 |
+
self.color_ref = copy.deepcopy(bias)
|
232 |
+
|
233 |
+
points = get_circle_coordinates(center, radius, k=num_segments * 3)
|
234 |
+
points = torch.FloatTensor(points)
|
235 |
+
else:
|
236 |
+
if self.attn_init:
|
237 |
+
p0 = self.select_inds[self.strokes_counter]
|
238 |
+
else:
|
239 |
+
p0 = self.pos_init_method()
|
240 |
+
|
241 |
+
self.color_ref = copy.deepcopy(p0)
|
242 |
+
points.append(p0)
|
243 |
+
for j in range(num_segments):
|
244 |
+
radius = self.radius
|
245 |
+
p1 = (p0[0] + radius * np.random.uniform(-0.5, 0.5),
|
246 |
+
p0[1] + radius * np.random.uniform(-0.5, 0.5))
|
247 |
+
p2 = (p1[0] + radius * np.random.uniform(-0.5, 0.5),
|
248 |
+
p1[1] + radius * np.random.uniform(-0.5, 0.5))
|
249 |
+
p3 = (p2[0] + radius * np.random.uniform(-0.5, 0.5),
|
250 |
+
p2[1] + radius * np.random.uniform(-0.5, 0.5))
|
251 |
+
points.append(p1)
|
252 |
+
points.append(p2)
|
253 |
+
if j < num_segments - 1:
|
254 |
+
points.append(p3)
|
255 |
+
p0 = p3
|
256 |
+
points = torch.FloatTensor(points)
|
257 |
+
|
258 |
+
path = pydiffvg.Path(num_control_points=torch.LongTensor(num_control_points),
|
259 |
+
points=points,
|
260 |
+
stroke_width=torch.tensor(0.0),
|
261 |
+
is_closed=True)
|
262 |
+
elif self.style in ['sketch', 'painting', 'ink']:
|
263 |
+
num_control_points = torch.zeros(num_segments, dtype=torch.long) + 2
|
264 |
+
points = []
|
265 |
+
|
266 |
+
if self.attn_init:
|
267 |
+
p0 = self.select_inds[self.strokes_counter]
|
268 |
+
else:
|
269 |
+
p0 = (random.random(), random.random()) \
|
270 |
+
if self.pos_init_method is None else self.pos_init_method()
|
271 |
+
|
272 |
+
self.color_ref = copy.deepcopy(p0)
|
273 |
+
|
274 |
+
points.append(p0)
|
275 |
+
for j in range(num_segments):
|
276 |
+
radius = 0.1
|
277 |
+
p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
|
278 |
+
p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5))
|
279 |
+
p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5))
|
280 |
+
points.append(p1)
|
281 |
+
points.append(p2)
|
282 |
+
points.append(p3)
|
283 |
+
p0 = p3
|
284 |
+
points = torch.tensor(points).to(self.device)
|
285 |
+
|
286 |
+
if not self.attn_init:
|
287 |
+
points[:, 0] *= self.canvas_width
|
288 |
+
points[:, 1] *= self.canvas_height
|
289 |
+
|
290 |
+
path = pydiffvg.Path(num_control_points=torch.LongTensor(num_control_points),
|
291 |
+
points=points,
|
292 |
+
stroke_width=torch.tensor(self.stroke_width),
|
293 |
+
is_closed=False)
|
294 |
+
elif self.style == 'pixelart':
|
295 |
+
x = coord[0] * self.pixel_per_grid
|
296 |
+
y = coord[1] * self.pixel_per_grid
|
297 |
+
points = torch.FloatTensor([
|
298 |
+
[x, y],
|
299 |
+
[x + self.pixel_per_grid, y],
|
300 |
+
[x + self.pixel_per_grid, y + self.pixel_per_grid],
|
301 |
+
[x, y + self.pixel_per_grid]
|
302 |
+
]).to(self.device)
|
303 |
+
path = pydiffvg.Polygon(points=points,
|
304 |
+
stroke_width=torch.tensor(0.0),
|
305 |
+
is_closed=True)
|
306 |
+
|
307 |
+
self.strokes_counter += 1
|
308 |
+
return path
|
309 |
+
|
310 |
+
def clip_curve_shape(self):
|
311 |
+
for group in self.shape_groups:
|
312 |
+
group.fill_color.data.clamp_(0.0, 1.0)
|
313 |
+
|
314 |
+
def reinitialize_paths(self,
|
315 |
+
reinit_path: bool = False,
|
316 |
+
opacity_threshold: float = None,
|
317 |
+
area_threshold: float = None,
|
318 |
+
fpath: pathlib.Path = None):
|
319 |
+
"""
|
320 |
+
reinitialize paths, also known as 'Reinitializing paths' in VectorFusion paper.
|
321 |
+
|
322 |
+
Args:
|
323 |
+
reinit_path: whether to reinitialize paths or not.
|
324 |
+
opacity_threshold: Threshold of opacity.
|
325 |
+
area_threshold: Threshold of the closed polygon area.
|
326 |
+
fpath: The path to save the reinitialized SVG.
|
327 |
+
"""
|
328 |
+
if self.style == 'iconography' and reinit_path:
|
329 |
+
# re-init by opacity_threshold
|
330 |
+
select_path_ids_by_opc = []
|
331 |
+
if opacity_threshold != 0 and opacity_threshold is not None:
|
332 |
+
def get_keys_below_threshold(my_dict, threshold):
|
333 |
+
keys_below_threshold = [key for key, value in my_dict.items() if value < threshold]
|
334 |
+
return keys_below_threshold
|
335 |
+
|
336 |
+
opacity_record_ = {group.shape_ids.item(): group.fill_color.data[-1].item()
|
337 |
+
for group in self.cur_shape_groups}
|
338 |
+
# print("-> opacity_record: ", opacity_record_)
|
339 |
+
print("-> opacity_record: ", [f"{k}: {v:.3f}" for k, v in opacity_record_.items()])
|
340 |
+
select_path_ids_by_opc = get_keys_below_threshold(opacity_record_, opacity_threshold)
|
341 |
+
print("select_path_ids_by_opc: ", select_path_ids_by_opc)
|
342 |
+
|
343 |
+
# remove path by area_threshold
|
344 |
+
select_path_ids_by_area = []
|
345 |
+
if area_threshold != 0 and area_threshold is not None:
|
346 |
+
area_records = [Polygon(shape.points.detach().numpy()).area for shape in self.cur_shapes]
|
347 |
+
        # print("-> area_records: ", area_records)
        print("-> area_records: ", ['%.2f' % i for i in area_records])
        for i, shape in enumerate(self.cur_shapes):
            if Polygon(shape.points.detach().numpy()).area < area_threshold:
                select_path_ids_by_area.append(shape.id)
        print("select_path_ids_by_area: ", select_path_ids_by_area)

        # re-init paths
        reinit_union = list(set(select_path_ids_by_opc + select_path_ids_by_area))
        if len(reinit_union) > 0:
            for i, path in enumerate(self.cur_shapes):
                if path.id in reinit_union:
                    self.cur_shapes[i] = self.get_path()
            for i, group in enumerate(self.cur_shape_groups):
                shp_ids = group.shape_ids.cpu().numpy().tolist()
                if set(shp_ids).issubset(reinit_union):
                    fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
                    fill_color_init[-1] = np.random.uniform(0.7, 1)
                    stroke_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
                    self.cur_shape_groups[i] = pydiffvg.ShapeGroup(
                        shape_ids=torch.tensor(list(shp_ids)),
                        fill_color=fill_color_init,
                        stroke_color=stroke_color_init)
            # save reinit svg
            self.save_svg(fpath)

        print("-" * 40)

    def render_warp(self, seed=0):
        scene_args = pydiffvg.RenderFunction.serialize_scene(
            self.canvas_width, self.canvas_height, self.shapes, self.shape_groups
        )
        _render = pydiffvg.RenderFunction.apply
        img = _render(self.canvas_width,  # width
                      self.canvas_height,  # height
                      2,  # num_samples_x
                      2,  # num_samples_y
                      seed,  # seed
                      None,
                      *scene_args)
        return img

    def calc_distance_weight(self, loss_weight_keep):
        shapes_forsdf = copy.deepcopy(self.cur_shapes)
        shape_groups_forsdf = copy.deepcopy(self.cur_shape_groups)
        for si in shapes_forsdf:
            si.stroke_width = torch.FloatTensor([0]).to(self.device)
        for sg_idx, sgi in enumerate(shape_groups_forsdf):
            sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(self.device)
            sgi.shape_ids = torch.LongTensor([sg_idx]).to(self.device)

        sargs_forsdf = pydiffvg.RenderFunction.serialize_scene(
            self.canvas_width, self.canvas_height, shapes_forsdf, shape_groups_forsdf
        )
        _render = pydiffvg.RenderFunction.apply
        with torch.no_grad():
            im_forsdf = _render(self.canvas_width,  # width
                                self.canvas_height,  # height
                                2,  # num_samples_x
                                2,  # num_samples_y
                                0,  # seed
                                None,
                                *sargs_forsdf)

        # using the alpha channel is a trick to get a 0-1 image
        im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy()
        loss_weight = get_sdf(im_forsdf, normalize='to1')
        loss_weight += loss_weight_keep
        loss_weight = np.clip(loss_weight, 0, 1)
        loss_weight = torch.FloatTensor(loss_weight).to(self.device)
        return loss_weight

    def set_points_parameters(self, id_delta=0):
        self.point_vars = []
        for i, path in enumerate(self.cur_shapes):
            path.id = i + id_delta  # set point id
            path.points.requires_grad = True
            self.point_vars.append(path.points)

    def get_point_params(self):
        return self.point_vars

    def set_color_parameters(self):
        self.color_vars = []
        for i, group in enumerate(self.cur_shape_groups):
            if group.fill_color is not None:
                group.fill_color.requires_grad = True
                self.color_vars.append(group.fill_color)
            if group.stroke_color is not None:
                group.stroke_color.requires_grad = True
                self.color_vars.append(group.stroke_color)

    def get_color_params(self):
        return self.color_vars

    def set_width_parameters(self):
        # stroke width optimization
        self.width_vars = []
        for i, path in enumerate(self.shapes):
            path.stroke_width.requires_grad = True
            self.width_vars.append(path.stroke_width)

    def get_width_params(self):
        return self.width_vars

    def save_svg(self, fpath):
        pydiffvg.save_svg(f'{fpath}',
                          self.canvas_width,
                          self.canvas_height,
                          self.shapes,
                          self.shape_groups)

    def load_svg(self, path_svg):
        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
        return canvas_width, canvas_height, shapes, shape_groups


def softmax_t(x, tau=0.2):
    e_x = np.exp(x / tau)
    return e_x / e_x.sum()


def get_circle_coordinates(center, radius, k):
    coordinates = []
    cx, cy = center
    angle = 2 * math.pi / k

    for i in range(k):
        theta = i * angle  # cur angle
        x = cx + radius * math.cos(theta)  # x
        y = cy + radius * math.sin(theta)  # y
        coordinates.append((x, y))

    return coordinates


class LinearDecayLRLambda:

    def __init__(self, init_lr, keep_ratio, decay_every, decay_ratio):
        self.init_lr = init_lr
        self.keep_ratio = keep_ratio
        self.decay_every = decay_every
        self.decay_ratio = decay_ratio

    def __call__(self, n):
        if n < self.keep_ratio * self.decay_every:
            return self.init_lr

        decay_time = n // self.decay_every
        decay_step = n % self.decay_every
        lr_s = self.decay_ratio ** decay_time
        lr_e = self.decay_ratio ** (decay_time + 1)
        r = decay_step / self.decay_every
        lr = lr_s * (1 - r) + lr_e * r
        return lr
+
|
503 |
+
|
504 |
+
class CompPainterOptimizer:
|
505 |
+
|
506 |
+
def __init__(self,
|
507 |
+
renderer: CompPainter,
|
508 |
+
style: str,
|
509 |
+
num_iter: int,
|
510 |
+
lr_config: DictConfig,
|
511 |
+
optim_bg: bool = False):
|
512 |
+
self.renderer = renderer
|
513 |
+
self.style = style
|
514 |
+
self.num_iter = num_iter
|
515 |
+
self.lr_config = lr_config
|
516 |
+
schedule_cfg = self.lr_config.schedule
|
517 |
+
self.optim_bg = optim_bg
|
518 |
+
|
519 |
+
if style == 'iconography':
|
520 |
+
self.optim_point, self.optim_color, self.optim_width = True, True, False
|
521 |
+
self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio,
|
522 |
+
self.num_iter, schedule_cfg.decay_ratio)
|
523 |
+
if style == 'pixelart':
|
524 |
+
self.optim_point, self.optim_color, self.optim_width = False, True, False
|
525 |
+
self.point_lr_lambda = None
|
526 |
+
if style == 'sketch':
|
527 |
+
self.optim_point, self.optim_color, self.optim_width = True, False, False
|
528 |
+
self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio,
|
529 |
+
self.num_iter, schedule_cfg.decay_ratio)
|
530 |
+
if style == 'ink':
|
531 |
+
self.optim_point, self.optim_color, self.optim_width = True, False, True
|
532 |
+
self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio,
|
533 |
+
self.num_iter, schedule_cfg.decay_ratio)
|
534 |
+
if style == 'painting':
|
535 |
+
self.optim_point, self.optim_color, self.optim_width = True, True, True
|
536 |
+
self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio,
|
537 |
+
self.num_iter, schedule_cfg.decay_ratio)
|
538 |
+
|
539 |
+
self.point_optimizer = None
|
540 |
+
self.color_optimizer = None
|
541 |
+
self.width_optimizer = None
|
542 |
+
self.bg_optimizer = None
|
543 |
+
|
544 |
+
self.point_scheduler = None
|
545 |
+
|
546 |
+
def init_optimizers(self, pid_delta=0):
|
547 |
+
optim_cfg = self.lr_config.optim
|
548 |
+
optim_name = optim_cfg.name
|
549 |
+
|
550 |
+
params = {}
|
551 |
+
if self.optim_point:
|
552 |
+
self.renderer.set_points_parameters(pid_delta)
|
553 |
+
params['point'] = self.renderer.get_point_params()
|
554 |
+
|
555 |
+
if len(params['point']) > 0:
|
556 |
+
self.point_optimizer = get_optimizer(optim_name, params['point'], self.lr_config.point, optim_cfg)
|
557 |
+
if self.point_lr_lambda is not None:
|
558 |
+
self.point_scheduler = LambdaLR(self.point_optimizer, lr_lambda=self.point_lr_lambda, last_epoch=-1)
|
559 |
+
|
560 |
+
if self.optim_color:
|
561 |
+
self.renderer.set_color_parameters()
|
562 |
+
params['color'] = self.renderer.get_color_params()
|
563 |
+
if len(params['color']) > 0:
|
564 |
+
self.color_optimizer = get_optimizer(optim_name, params['color'], self.lr_config.color, optim_cfg)
|
565 |
+
|
566 |
+
if self.optim_width:
|
567 |
+
self.renderer.set_width_parameters()
|
568 |
+
params['width'] = self.renderer.get_width_params()
|
569 |
+
if len(params['width']) > 0:
|
570 |
+
self.width_optimizer = get_optimizer(optim_name, params['width'], self.lr_config.width, optim_cfg)
|
571 |
+
|
572 |
+
if self.optim_bg:
|
573 |
+
self.renderer.para_bg.requires_grad = True
|
574 |
+
self.bg_optimizer = get_optimizer(optim_name, self.renderer.para_bg, self.lr_config.bg, optim_cfg)
|
575 |
+
|
576 |
+
def update_lr(self):
|
577 |
+
if self.point_scheduler is not None:
|
578 |
+
self.point_scheduler.step()
|
579 |
+
|
580 |
+
def zero_grad_(self):
|
581 |
+
if self.point_optimizer is not None:
|
582 |
+
self.point_optimizer.zero_grad()
|
583 |
+
if self.color_optimizer is not None:
|
584 |
+
self.color_optimizer.zero_grad()
|
585 |
+
if self.width_optimizer is not None:
|
586 |
+
self.width_optimizer.zero_grad()
|
587 |
+
if self.bg_optimizer is not None:
|
588 |
+
self.bg_optimizer.zero_grad()
|
589 |
+
|
590 |
+
def step_(self):
|
591 |
+
if self.point_optimizer is not None:
|
592 |
+
self.point_optimizer.step()
|
593 |
+
if self.color_optimizer is not None:
|
594 |
+
self.color_optimizer.step()
|
595 |
+
if self.width_optimizer is not None:
|
596 |
+
self.width_optimizer.step()
|
597 |
+
if self.bg_optimizer is not None:
|
598 |
+
self.bg_optimizer.step()
|
599 |
+
|
600 |
+
def get_lr(self) -> Dict:
|
601 |
+
lr = {}
|
602 |
+
if self.point_optimizer is not None:
|
603 |
+
lr['pnt'] = self.point_optimizer.param_groups[0]['lr']
|
604 |
+
if self.color_optimizer is not None:
|
605 |
+
lr['clr'] = self.color_optimizer.param_groups[0]['lr']
|
606 |
+
if self.width_optimizer is not None:
|
607 |
+
lr['wd'] = self.width_optimizer.param_groups[0]['lr']
|
608 |
+
if self.bg_optimizer is not None:
|
609 |
+
lr['bg'] = self.bg_optimizer.param_groups[0]['lr']
|
610 |
+
return lr
|
svgdreamer/painter/diffusion_pipeline.py
ADDED
@@ -0,0 +1,402 @@
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import PIL
from PIL import Image
from typing import Any, List, Optional, Union, Dict
from omegaconf import DictConfig

import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    rescale_noise_cfg, StableDiffusionPipelineOutput)

from svgdreamer.diffusers_warp import init_StableDiffusion_pipeline
from svgdreamer.token2attn.attn_control import AttentionStore
from svgdreamer.token2attn.ptp_utils import text_under_image, view_images


class DiffusionPipeline(torch.nn.Module):

    def __init__(self, model_cfg: DictConfig, diffuser_cfg: DictConfig, device: torch.device):
        super().__init__()
        self.device = device

        pipe_kwargs = {
            "device": self.device,
            "torch_dtype": torch.float32,
            "local_files_only": not diffuser_cfg.download,
            "force_download": diffuser_cfg.force_download,
            "resume_download": diffuser_cfg.resume_download,
            "ldm_speed_up": model_cfg.ldm_speed_up,
            "enable_xformers": model_cfg.enable_xformers,
            "gradient_checkpoint": model_cfg.gradient_checkpoint,
            "cpu_offload": model_cfg.cpu_offload,
            "vae_slicing": False
        }

        # load pretrained model
        self.sd_pipeline = init_StableDiffusion_pipeline(
            model_cfg.model_id,
            custom_pipeline=StableDiffusionPipeline,
            custom_scheduler=DDIMScheduler,
            **pipe_kwargs
        )
        # disable grads
        self.sd_pipeline.vae.requires_grad_(False)
        self.sd_pipeline.text_encoder.requires_grad_(False)
        self.sd_pipeline.unet.requires_grad_(False)
        # set components
        self.vae = self.sd_pipeline.vae
        self.unet = self.sd_pipeline.unet
        self.scheduler = self.sd_pipeline.scheduler
        self.tokenizer = self.sd_pipeline.tokenizer
        self.text_encoder = self.sd_pipeline.text_encoder

    @torch.no_grad()
    def encode_prompt(self,
                      prompt,
                      device,
                      do_classifier_free_guidance,
                      negative_prompt=None):
        # text conditional embed
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        prompt_embeds = self.text_encoder(text_inputs.input_ids.to(device))[0]

        if do_classifier_free_guidance:
            if negative_prompt is None:
                uncond_tokens = [""]
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            else:
                uncond_tokens = negative_prompt

            # unconditional embed
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=prompt_embeds.shape[1],
                truncation=True,
                return_tensors="pt",
            )
            negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device))[0]

            concat_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            return concat_prompt_embeds, negative_prompt_embeds, prompt_embeds

        return prompt_embeds, None, None

    def register_attention_control(self, controller):
        attn_procs = {}
        cross_att_count = 0
        for name in self.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = self.unet.config.block_out_channels[-1]
                place_in_unet = "mid"
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
                place_in_unet = "up"
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.unet.config.block_out_channels[block_id]
                place_in_unet = "down"
            else:
                continue
            cross_att_count += 1
            attn_procs[name] = P2PCrossAttnProcessor(
                controller=controller, place_in_unet=place_in_unet
            )

        self.unet.set_attn_processor(attn_procs)
        controller.num_att_layers = cross_att_count

    @staticmethod
    def aggregate_attention(prompts,
                            attention_store: AttentionStore,
                            res: int,
                            from_where: List[str],
                            is_cross: bool,
                            select: int):
        if isinstance(prompts, str):
            prompts = [prompts]
        assert isinstance(prompts, list)

        out = []
        attention_maps = attention_store.get_average_attention()
        num_pixels = res ** 2
        for location in from_where:
            for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
                if item.shape[1] == num_pixels:
                    cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                    out.append(cross_maps)
        out = torch.cat(out, dim=0)
        out = out.sum(0) / out.shape[0]
        return out.cpu()

    def get_cross_attention(self,
                            prompts,
                            attention_store: AttentionStore,
                            res: int,
                            from_where: List[str],
                            select: int = 0,
                            save_path=None):
        tokens = self.tokenizer.encode(prompts[select])
        decoder = self.tokenizer.decode
        # shape: [res ** 2, res ** 2, seq_len]
        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, True, select)

        images_text = []
        images = []
        for i in range(len(tokens)):
            image = attention_maps[:, :, i]
            image = 255 * image / image.max()
            image = image.unsqueeze(-1).expand(*image.shape, 3)
            image = image.numpy().astype(np.uint8)
            image = np.array(Image.fromarray(image).resize((256, 256)))
            images.append(np.copy(image))
            image = text_under_image(image, decoder(int(tokens[i])))
            images_text.append(image)
        image_array = np.stack(images_text, axis=0)
        view_images(image_array, save_image=True, fp=save_path)

        return attention_maps, tokens

    def get_self_attention_comp(self,
                                prompts,
                                attention_store: AttentionStore,
                                res: int,
                                from_where: List[str],
                                img_size: int = 224,
                                max_com=10,
                                select: int = 0,
                                save_path=None):
        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, False, select)
        attention_maps = attention_maps.numpy().reshape((res ** 2, res ** 2))
        # shape: [res ** 2, res ** 2]
        u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
        print(f"self-attention maps: {attention_maps.shape}, "
              f"u: {u.shape}, "
              f"s: {s.shape}, "
              f"vh: {vh.shape}")

        images = []
        vh_returns = []
        for i in range(max_com):
            image = vh[i].reshape(res, res)
            image = (image - image.min()) / (image.max() - image.min())
            image = 255 * image

            ret_ = Image.fromarray(image).resize((img_size, img_size), resample=PIL.Image.Resampling.BILINEAR)
            vh_returns.append(np.array(ret_))

            image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
            image = Image.fromarray(image).resize((256, 256))
            image = np.array(image)
            images.append(image)
        image_array = np.stack(images, axis=0)
        view_images(image_array, num_rows=max_com // 10, offset_ratio=0,
                    save_image=True, fp=save_path / "self-attn-vh.png")

        return attention_maps, (u, s, vh), np.stack(vh_returns, axis=0)

    def sampling(self,
                 vae,
                 unet,
                 scheduler,
                 prompt: Union[str, List[str]] = None,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 controller: AttentionStore = None,  # feed attention_store as control of ptp
                 num_inference_steps: int = 50,
                 guidance_scale: float = 7.5,
                 negative_prompt: Optional[Union[str, List[str]]] = None,
                 num_images_per_prompt: Optional[int] = 1,
                 eta: float = 0.0,
                 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
                 latents: Optional[torch.FloatTensor] = None,
                 output_type: Optional[str] = "pil",
                 return_dict: bool = True,
                 cross_attention_kwargs: Optional[Dict[str, Any]] = None,
                 guidance_rescale: float = 0.0):

        # add attention controller
        self.register_attention_control(controller)

        # 0. Default height and width to unet
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
        height = height or unet.config.sample_size * vae_scale_factor
        width = width or unet.config.sample_size * vae_scale_factor

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = 1

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, _, _ = self.encode_prompt(
            prompt,
            self.device,
            do_classifier_free_guidance,
            negative_prompt,
        )

        # 4. Prepare timesteps
        scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = unet.config.in_channels
        latents = self.sd_pipeline.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            self.device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.sd_pipeline.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.sd_pipeline.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # controller callback
                latents = controller.step_callback(latents)

                # update progress_bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.sd_pipeline.run_safety_checker(image, self.device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.sd_pipeline.image_processor.postprocess(image, output_type=output_type,
                                                             do_denormalize=do_denormalize)
        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def sample(self,
               prompt,
               height: Optional[int] = None,
               width: Optional[int] = None,
               controller: AttentionStore = None,  # feed attention_store as control of ptp
               num_inference_steps: int = 50,
               guidance_scale: float = 7.5,
               negative_prompt: Optional[Union[str, List[str]]] = None,
               generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
               output_type: Optional[str] = "pil"):
        return self.sampling(self.vae, self.unet, self.scheduler,
                             prompt=prompt,
                             height=height, width=width,
                             controller=controller,
                             num_inference_steps=num_inference_steps,
                             guidance_scale=guidance_scale,
                             negative_prompt=negative_prompt,
                             generator=generator,
                             output_type=output_type)

    def encode2latent(self, images):
        images = (2 * images - 1).clamp(-1.0, 1.0)  # images: [B, 3, H, W]
        # encode images
        latents = self.vae.encode(images).latent_dist.sample()
        latents = self.vae.config.scaling_factor * latents
        return latents
365 |
+
|
366 |
+
|
367 |
+
class P2PCrossAttnProcessor:
|
368 |
+
|
369 |
+
def __init__(self, controller, place_in_unet):
|
370 |
+
super().__init__()
|
371 |
+
self.controller = controller
|
372 |
+
self.place_in_unet = place_in_unet
|
373 |
+
|
374 |
+
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
375 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
376 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size)
|
377 |
+
|
378 |
+
query = attn.to_q(hidden_states)
|
379 |
+
|
380 |
+
is_cross = encoder_hidden_states is not None
|
381 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
382 |
+
key = attn.to_k(encoder_hidden_states)
|
383 |
+
value = attn.to_v(encoder_hidden_states)
|
384 |
+
|
385 |
+
query = attn.head_to_batch_dim(query)
|
386 |
+
key = attn.head_to_batch_dim(key)
|
387 |
+
value = attn.head_to_batch_dim(value)
|
388 |
+
|
389 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
390 |
+
|
391 |
+
# one line change
|
392 |
+
self.controller(attention_probs, is_cross, self.place_in_unet)
|
393 |
+
|
394 |
+
hidden_states = torch.bmm(attention_probs, value)
|
395 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
396 |
+
|
397 |
+
# linear proj
|
398 |
+
hidden_states = attn.to_out[0](hidden_states)
|
399 |
+
# dropout
|
400 |
+
hidden_states = attn.to_out[1](hidden_states)
|
401 |
+
|
402 |
+
return hidden_states
|
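
A minimal end-to-end sketch of how this pipeline might be driven. The config objects, the prompt text, and the AttentionStore construction are assumptions for illustration; only the `DiffusionPipeline.__init__`, `sample`, and `get_cross_attention` signatures come from the code above.

    import torch
    from svgdreamer.token2attn.attn_control import AttentionStore

    # model_cfg / diffuser_cfg are assumed to be DictConfigs loaded from the project's Hydra configs
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = DiffusionPipeline(model_cfg, diffuser_cfg, device)

    prompt = "Sydney Opera House, minimal flat vector icon"  # illustrative prompt
    controller = AttentionStore()  # assumes default construction is supported
    out = pipe.sample(prompt=[prompt],
                      controller=controller,
                      num_inference_steps=50,
                      guidance_scale=7.5,
                      generator=torch.Generator(device).manual_seed(0))
    out.images[0].save("sample.png")

    # per-token cross-attention maps, averaged over the stored UNet blocks
    attn_maps, tokens = pipe.get_cross_attention([prompt], controller, res=16,
                                                 from_where=["up", "down"],
                                                 save_path="cross_attn.png")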