xingxm committed
Commit 1b7b364 · 1 Parent(s): 33f9237

feat(project): init SVGDreamer

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. ImageReward/ImageReward.py +177 -0
  2. ImageReward/ReFL.py +830 -0
  3. ImageReward/__init__.py +3 -0
  4. ImageReward/models/AestheticScore.py +95 -0
  5. ImageReward/models/BLIP/__init__.py +1 -0
  6. ImageReward/models/BLIP/blip.py +70 -0
  7. ImageReward/models/BLIP/blip_pretrain.py +43 -0
  8. ImageReward/models/BLIP/med.py +947 -0
  9. ImageReward/models/BLIP/vit.py +301 -0
  10. ImageReward/models/BLIPScore.py +97 -0
  11. ImageReward/models/CLIPScore.py +78 -0
  12. ImageReward/models/__init__.py +4 -0
  13. ImageReward/utils.py +184 -0
  14. README.md +86 -12
  15. assets/Icon-SydneyOperaHouse/init_p0.svg +0 -0
  16. assets/Icon-SydneyOperaHouse/init_p1.svg +0 -0
  17. assets/Icon-SydneyOperaHouse/init_p2.svg +0 -0
  18. assets/Icon-SydneyOperaHouse/init_p3.svg +0 -0
  19. assets/Icon-SydneyOperaHouse/init_p4.svg +0 -0
  20. assets/Icon-SydneyOperaHouse/init_p5.svg +0 -0
  21. assets/Icon-SydneyOperaHouse/p_0.svg +0 -0
  22. assets/Icon-SydneyOperaHouse/p_1.svg +0 -0
  23. assets/Icon-SydneyOperaHouse/p_2.svg +0 -0
  24. assets/Icon-SydneyOperaHouse/p_3.svg +0 -0
  25. assets/Icon-SydneyOperaHouse/p_4.svg +0 -0
  26. assets/Icon-SydneyOperaHouse/p_5.svg +0 -0
  27. assets/{teaser1.png → illustrate.png} +2 -2
  28. assets/{teaser2.png → teaser_cases.png} +2 -2
  29. assets/{teaser3.png → teaser_more_cases.png} +2 -2
  30. assets/teaser_svg_asset.png +3 -0
  31. conf/config.yaml +54 -0
  32. conf/x/iconography.yaml +188 -0
  33. conf/x/ink.yaml +188 -0
  34. conf/x/lowpoly.yaml +188 -0
  35. conf/x/painting.yaml +188 -0
  36. conf/x/pixelart.yaml +188 -0
  37. conf/x/sketch.yaml +188 -0
  38. svgdreamer.py +42 -0
  39. svgdreamer/__init__.py +6 -0
  40. svgdreamer/diffusers_warp/__init__.py +248 -0
  41. svgdreamer/diffvg_warp/__init__.py +11 -0
  42. svgdreamer/diffvg_warp/diffvg_state.py +299 -0
  43. svgdreamer/libs/__init__.py +8 -0
  44. svgdreamer/libs/logging.py +65 -0
  45. svgdreamer/libs/model_state.py +253 -0
  46. svgdreamer/libs/optim.py +58 -0
  47. svgdreamer/painter/VPSD_pipeline.py +585 -0
  48. svgdreamer/painter/__init__.py +10 -0
  49. svgdreamer/painter/component_painter_params.py +610 -0
  50. svgdreamer/painter/diffusion_pipeline.py +402 -0
ImageReward/ImageReward.py ADDED
@@ -0,0 +1,177 @@
1
+ '''
2
+ @File : ImageReward.py
3
+ @Time : 2023/01/28 19:53:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ @Description: ImageReward Reward model.
7
+ * Based on CLIP code base and improved-aesthetic-predictor code base
8
+ * https://github.com/openai/CLIP
9
+ * https://github.com/christophschuhmann/improved-aesthetic-predictor
10
+ '''
11
+
12
+ import os
13
+ import torch
14
+ import torch.nn as nn
15
+ from PIL import Image
16
+ from .models.BLIP.blip_pretrain import BLIP_Pretrain
17
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
18
+
19
+ try:
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+ BICUBIC = InterpolationMode.BICUBIC
23
+ except ImportError:
24
+ BICUBIC = Image.BICUBIC
25
+
26
+
27
+ def _convert_image_to_rgb(image):
28
+ return image.convert("RGB")
29
+
30
+
31
+ def _transform(n_px):
32
+ return Compose([
33
+ Resize(n_px, interpolation=BICUBIC),
34
+ CenterCrop(n_px),
35
+ _convert_image_to_rgb,
36
+ ToTensor(),
37
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
38
+ ])
39
+
40
+
41
+ class MLP(nn.Module):
42
+ def __init__(self, input_size):
43
+ super().__init__()
44
+ self.input_size = input_size
45
+
46
+ self.layers = nn.Sequential(
47
+ nn.Linear(self.input_size, 1024),
48
+ # nn.ReLU(),
49
+ nn.Dropout(0.2),
50
+ nn.Linear(1024, 128),
51
+ # nn.ReLU(),
52
+ nn.Dropout(0.2),
53
+ nn.Linear(128, 64),
54
+ # nn.ReLU(),
55
+ nn.Dropout(0.1),
56
+ nn.Linear(64, 16),
57
+ # nn.ReLU(),
58
+ nn.Linear(16, 1)
59
+ )
60
+
61
+ # initialize the MLP parameters
62
+ for name, param in self.layers.named_parameters():
63
+ if 'weight' in name:
64
+ nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
65
+ if 'bias' in name:
66
+ nn.init.constant_(param, val=0)
67
+
68
+ def forward(self, input):
69
+ return self.layers(input)
70
+
71
+
72
+ class ImageReward(nn.Module):
73
+ def __init__(self, med_config, device='cpu'):
74
+ super().__init__()
75
+ self.device = device
76
+
77
+ self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
78
+ self.preprocess = _transform(224)
79
+ self.mlp = MLP(768)
80
+
81
+ self.mean = 0.16717362830052426
82
+ self.std = 1.0333394966054072
83
+
84
+ def score_gard(self, prompt_ids, prompt_attention_mask, image):
85
+
86
+ image_embeds = self.blip.visual_encoder(image)
87
+ # text encode cross attention with image
88
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
89
+ text_output = self.blip.text_encoder(prompt_ids,
90
+ attention_mask=prompt_attention_mask,
91
+ encoder_hidden_states=image_embeds,
92
+ encoder_attention_mask=image_atts,
93
+ return_dict=True,
94
+ )
95
+
96
+ txt_features = text_output.last_hidden_state[:, 0, :] # (feature_dim)
97
+ rewards = self.mlp(txt_features)
98
+ rewards = (rewards - self.mean) / self.std
99
+
100
+ return rewards
101
+
102
+ def score(self, prompt, image):
103
+
104
+ if (type(image).__name__ == 'list'):
105
+ _, rewards = self.inference_rank(prompt, image)
106
+ return rewards
107
+
108
+ # text encode
109
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
110
+ return_tensors="pt").to(self.device)
111
+
112
+ # image encode
113
+ if isinstance(image, Image.Image):
114
+ pil_image = image
115
+ elif isinstance(image, str):
116
+ if os.path.isfile(image):
117
+ pil_image = Image.open(image)
118
+ else:
119
+ raise TypeError(
120
+ r'This image parameter type is not supported yet. Please pass a PIL.Image or a file path str.')
121
+
122
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
123
+ image_embeds = self.blip.visual_encoder(image)
124
+
125
+ # text encode cross attention with image
126
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
127
+ text_output = self.blip.text_encoder(text_input.input_ids,
128
+ attention_mask=text_input.attention_mask,
129
+ encoder_hidden_states=image_embeds,
130
+ encoder_attention_mask=image_atts,
131
+ return_dict=True,
132
+ )
133
+
134
+ txt_features = text_output.last_hidden_state[:, 0, :].float() # (feature_dim)
135
+ rewards = self.mlp(txt_features)
136
+ rewards = (rewards - self.mean) / self.std
137
+
138
+ return rewards.detach().cpu().numpy().item()
139
+
140
+ def inference_rank(self, prompt, generations_list):
141
+
142
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
143
+ return_tensors="pt").to(self.device)
144
+
145
+ txt_set = []
146
+ for generation in generations_list:
147
+ # image encode
148
+ if isinstance(generation, Image.Image):
149
+ pil_image = generation
150
+ elif isinstance(generation, str):
151
+ if os.path.isfile(generation):
152
+ pil_image = Image.open(generation)
153
+ else:
154
+ raise TypeError(
155
+ r'This image parameter type is not supported yet. Please pass a PIL.Image or a file path str.')
156
+
157
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
158
+ image_embeds = self.blip.visual_encoder(image)
159
+
160
+ # text encode cross attention with image
161
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
162
+ text_output = self.blip.text_encoder(text_input.input_ids,
163
+ attention_mask=text_input.attention_mask,
164
+ encoder_hidden_states=image_embeds,
165
+ encoder_attention_mask=image_atts,
166
+ return_dict=True)
167
+ txt_set.append(text_output.last_hidden_state[:, 0, :])
168
+
169
+ txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
170
+ rewards = self.mlp(txt_features) # [image_num, 1]
171
+ rewards = (rewards - self.mean) / self.std
172
+ rewards = torch.squeeze(rewards)
173
+ _, rank = torch.sort(rewards, dim=0, descending=True)
174
+ _, indices = torch.sort(rank, dim=0)
175
+ indices = indices + 1
176
+
177
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
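A minimal usage sketch for the reward model above, assuming the pretrained ImageReward-v1.0 checkpoint is fetched through RM.load() from ImageReward/utils.py (the same loader ReFL.py uses below); the device, prompt, and image paths are placeholders:

import ImageReward as RM

model = RM.load("ImageReward-v1.0", device="cuda")                  # placeholder device
prompt = "a minimalist icon of the Sydney Opera House"              # placeholder prompt
single_score = model.score(prompt, "render.png")                    # PIL.Image or file path -> float
ranking, rewards = model.inference_rank(prompt, ["cand_0.png", "cand_1.png"])  # 1-based ranks + scores

score() standardizes the MLP output with the fixed mean/std constants stored on the model, so returned values are roughly zero-centered.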
ImageReward/ReFL.py ADDED
@@ -0,0 +1,830 @@
1
+ '''
2
+ @File : ReFL.py
3
+ @Time : 2023/05/01 19:36:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ @Description: ReFL Algorithm.
7
+ * Based on diffusers code base
8
+ * https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
9
+ '''
10
+
11
+ import argparse
12
+ import logging
13
+ import math
14
+ import os
15
+ import random
16
+ from pathlib import Path
17
+
18
+ import accelerate
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+ import torch.utils.checkpoint
23
+ import transformers
24
+ from accelerate import Accelerator
25
+ from accelerate.logging import get_logger
26
+ from accelerate.utils import ProjectConfiguration, set_seed
27
+ from datasets import load_dataset
28
+ from huggingface_hub import create_repo, upload_folder
29
+ from packaging import version
30
+ from tqdm.auto import tqdm
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+
33
+ from PIL import Image
34
+ import ImageReward as RM
35
+
36
+ from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
37
+
38
+ try:
39
+ from torchvision.transforms import InterpolationMode
40
+
41
+ BICUBIC = InterpolationMode.BICUBIC
42
+ except ImportError:
43
+ BICUBIC = Image.BICUBIC
44
+
45
+ import diffusers
46
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
47
+ from diffusers.optimization import get_scheduler
48
+ from diffusers.training_utils import EMAModel
49
+ from diffusers.utils import check_min_version, deprecate
50
+ from diffusers.utils.import_utils import is_xformers_available
51
+
52
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
53
+ check_min_version("0.16.0.dev0")
54
+
55
+ logger = get_logger(__name__, log_level="INFO")
56
+
57
+ DATASET_NAME_MAPPING = {
58
+ "refl": ("image", "text"),
59
+ }
60
+
61
+
62
+ def parse_args():
63
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
64
+ parser.add_argument(
65
+ "--grad_scale", type=float, default=1e-3, help="Scale factor applied to the reward (grad) loss."
66
+ )
67
+ parser.add_argument(
68
+ "--input_pertubation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
69
+ )
70
+ parser.add_argument(
71
+ "--revision",
72
+ type=str,
73
+ default=None,
74
+ required=False,
75
+ help="Revision of pretrained model identifier from huggingface.co/models.",
76
+ )
77
+ parser.add_argument(
78
+ "--dataset_name",
79
+ type=str,
80
+ default=None,
81
+ help=(
82
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
83
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
84
+ " or to a folder containing files that 🤗 Datasets can understand."
85
+ ),
86
+ )
87
+ parser.add_argument(
88
+ "--dataset_config_name",
89
+ type=str,
90
+ default=None,
91
+ help="The config of the Dataset, leave as None if there's only one config.",
92
+ )
93
+ parser.add_argument(
94
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
95
+ )
96
+ parser.add_argument(
97
+ "--caption_column",
98
+ type=str,
99
+ default="text",
100
+ help="The column of the dataset containing a caption or a list of captions.",
101
+ )
102
+ parser.add_argument(
103
+ "--max_train_samples",
104
+ type=int,
105
+ default=None,
106
+ help=(
107
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
108
+ "value if set."
109
+ ),
110
+ )
111
+ parser.add_argument(
112
+ "--validation_prompts",
113
+ type=str,
114
+ default=None,
115
+ nargs="+",
116
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
117
+ )
118
+ parser.add_argument(
119
+ "--output_dir",
120
+ type=str,
121
+ default="checkpoint/refl",
122
+ help="The output directory where the model predictions and checkpoints will be written.",
123
+ )
124
+ parser.add_argument(
125
+ "--cache_dir",
126
+ type=str,
127
+ default=None,
128
+ help="The directory where the downloaded models and datasets will be stored.",
129
+ )
130
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
131
+ parser.add_argument(
132
+ "--resolution",
133
+ type=int,
134
+ default=512,
135
+ help=(
136
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
137
+ " resolution"
138
+ ),
139
+ )
140
+ parser.add_argument(
141
+ "--center_crop",
142
+ default=False,
143
+ action="store_true",
144
+ help=(
145
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
146
+ " cropped. The images will be resized to the resolution first before cropping."
147
+ ),
148
+ )
149
+ parser.add_argument(
150
+ "--random_flip",
151
+ action="store_true",
152
+ help="whether to randomly flip images horizontally",
153
+ )
154
+ parser.add_argument(
155
+ "--train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader."
156
+ )
157
+ parser.add_argument("--num_train_epochs", type=int, default=100)
158
+ parser.add_argument(
159
+ "--max_train_steps",
160
+ type=int,
161
+ default=100,
162
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
163
+ )
164
+ parser.add_argument(
165
+ "--gradient_accumulation_steps",
166
+ type=int,
167
+ default=4,
168
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
169
+ )
170
+ parser.add_argument(
171
+ "--gradient_checkpointing",
172
+ action="store_true",
173
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
174
+ )
175
+ parser.add_argument(
176
+ "--learning_rate",
177
+ type=float,
178
+ default=1e-5,
179
+ help="Initial learning rate (after the potential warmup period) to use.",
180
+ )
181
+ parser.add_argument(
182
+ "--scale_lr",
183
+ action="store_true",
184
+ default=False,
185
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
186
+ )
187
+ parser.add_argument(
188
+ "--lr_scheduler",
189
+ type=str,
190
+ default="constant",
191
+ help=(
192
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
193
+ ' "constant", "constant_with_warmup"]'
194
+ ),
195
+ )
196
+ parser.add_argument(
197
+ "--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
198
+ )
199
+ parser.add_argument(
200
+ "--snr_gamma",
201
+ type=float,
202
+ default=None,
203
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
204
+ "More details here: https://arxiv.org/abs/2303.09556.",
205
+ )
206
+ parser.add_argument(
207
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
208
+ )
209
+ parser.add_argument(
210
+ "--allow_tf32",
211
+ action="store_true",
212
+ help=(
213
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
214
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
215
+ ),
216
+ )
217
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
218
+ parser.add_argument(
219
+ "--non_ema_revision",
220
+ type=str,
221
+ default=None,
222
+ required=False,
223
+ help=(
224
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
225
+ " remote repository specified with --pretrained_model_name_or_path."
226
+ ),
227
+ )
228
+ parser.add_argument(
229
+ "--dataloader_num_workers",
230
+ type=int,
231
+ default=0,
232
+ help=(
233
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
234
+ ),
235
+ )
236
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
237
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
238
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
239
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
240
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
241
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
242
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
243
+ parser.add_argument(
244
+ "--hub_model_id",
245
+ type=str,
246
+ default=None,
247
+ help="The name of the repository to keep in sync with the local `output_dir`.",
248
+ )
249
+ parser.add_argument(
250
+ "--logging_dir",
251
+ type=str,
252
+ default="logs",
253
+ help=(
254
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
255
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
256
+ ),
257
+ )
258
+ parser.add_argument(
259
+ "--mixed_precision",
260
+ type=str,
261
+ default=None,
262
+ choices=["no", "fp16", "bf16"],
263
+ help=(
264
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
265
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
266
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
267
+ ),
268
+ )
269
+ parser.add_argument(
270
+ "--report_to",
271
+ type=str,
272
+ default="tensorboard",
273
+ help=(
274
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
275
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
276
+ ),
277
+ )
278
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
279
+ parser.add_argument(
280
+ "--checkpointing_steps",
281
+ type=int,
282
+ default=100,
283
+ help=(
284
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
285
+ " training using `--resume_from_checkpoint`."
286
+ ),
287
+ )
288
+ parser.add_argument(
289
+ "--checkpoints_total_limit",
290
+ type=int,
291
+ default=None,
292
+ help=(
293
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
294
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
295
+ " for more docs"
296
+ ),
297
+ )
298
+ parser.add_argument(
299
+ "--resume_from_checkpoint",
300
+ type=str,
301
+ default=None,
302
+ help=(
303
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
304
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
305
+ ),
306
+ )
307
+ parser.add_argument(
308
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
309
+ )
310
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
311
+ parser.add_argument(
312
+ "--validation_epochs",
313
+ type=int,
314
+ default=5,
315
+ help="Run validation every X epochs.",
316
+ )
317
+ parser.add_argument(
318
+ "--tracker_project_name",
319
+ type=str,
320
+ default="text2image-refl",
321
+ help=(
322
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
323
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
324
+ ),
325
+ )
326
+
327
+ args = parser.parse_args()
328
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
329
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
330
+ args.local_rank = env_local_rank
331
+
332
+ # default to using the same revision for the non-ema model if not specified
333
+ if args.non_ema_revision is None:
334
+ args.non_ema_revision = args.revision
335
+
336
+ return args
337
+
338
+
339
+ class Trainer(object):
340
+
341
+ def __init__(self, pretrained_model_name_or_path, train_data_dir, args):
342
+
343
+ self.pretrained_model_name_or_path = pretrained_model_name_or_path
344
+ self.train_data_dir = train_data_dir
345
+
346
+ # Sanity checks
347
+ if args.dataset_name is None and self.train_data_dir is None:
348
+ raise ValueError("Need either a dataset name or a training folder.")
349
+
350
+ if args.non_ema_revision is not None:
351
+ deprecate(
352
+ "non_ema_revision!=None",
353
+ "0.15.0",
354
+ message=(
355
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
356
+ " use `--variant=non_ema` instead."
357
+ ),
358
+ )
359
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
360
+
361
+ accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
362
+
363
+ self.accelerator = Accelerator(
364
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
365
+ mixed_precision=args.mixed_precision,
366
+ log_with=args.report_to,
367
+ logging_dir=logging_dir,
368
+ project_config=accelerator_project_config,
369
+ )
370
+
371
+ # Make one log on every process with the configuration for debugging.
372
+ logging.basicConfig(
373
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
374
+ datefmt="%m/%d/%Y %H:%M:%S",
375
+ level=logging.INFO,
376
+ )
377
+ logger.info(self.accelerator.state, main_process_only=False)
378
+ if self.accelerator.is_local_main_process:
379
+ transformers.utils.logging.set_verbosity_warning()
380
+ diffusers.utils.logging.set_verbosity_info()
381
+ else:
382
+ transformers.utils.logging.set_verbosity_error()
383
+ diffusers.utils.logging.set_verbosity_error()
384
+
385
+ # If passed along, set the training seed now.
386
+ if args.seed is not None:
387
+ set_seed(args.seed)
388
+
389
+ # Handle the repository creation
390
+ if self.accelerator.is_main_process:
391
+ if args.output_dir is not None:
392
+ os.makedirs(args.output_dir, exist_ok=True)
393
+
394
+ if args.push_to_hub:
395
+ self.repo_id = create_repo(
396
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
397
+ ).repo_id
398
+
399
+ # Load scheduler, tokenizer and models.
400
+ self.noise_scheduler = DDPMScheduler.from_pretrained(self.pretrained_model_name_or_path, subfolder="scheduler")
401
+ tokenizer = CLIPTokenizer.from_pretrained(
402
+ self.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
403
+ )
404
+ self.text_encoder = CLIPTextModel.from_pretrained(
405
+ self.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
406
+ )
407
+ self.vae = AutoencoderKL.from_pretrained(self.pretrained_model_name_or_path, subfolder="vae",
408
+ revision=args.revision)
409
+ self.unet = UNet2DConditionModel.from_pretrained(
410
+ self.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
411
+ )
412
+ self.reward_model = RM.load("ImageReward-v1.0", device=self.accelerator.device)
413
+
414
+ # Freeze vae and text_encoder
415
+ self.vae.requires_grad_(False)
416
+ self.text_encoder.requires_grad_(False)
417
+ self.reward_model.requires_grad_(False)
418
+
419
+ # Create EMA for the unet.
420
+ if args.use_ema:
421
+ self.ema_unet = UNet2DConditionModel.from_pretrained(
422
+ self.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
423
+ )
424
+ self.ema_unet = EMAModel(self.ema_unet.parameters(), model_cls=UNet2DConditionModel,
425
+ model_config=self.ema_unet.config)
426
+
427
+ if args.enable_xformers_memory_efficient_attention:
428
+ if is_xformers_available():
429
+ import xformers
430
+
431
+ xformers_version = version.parse(xformers.__version__)
432
+ if xformers_version == version.parse("0.0.16"):
433
+ logger.warn(
434
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
435
+ )
436
+ self.unet.enable_xformers_memory_efficient_attention()
437
+ else:
438
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
439
+
440
+ # `accelerate` 0.16.0 will have better support for customized saving
441
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
442
+ # create custom saving & loading hooks so that `self.accelerator.save_state(...)` serializes in a nice format
443
+ def save_model_hook(models, weights, output_dir):
444
+ if args.use_ema:
445
+ self.ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
446
+
447
+ for i, model in enumerate(models):
448
+ model.save_pretrained(os.path.join(output_dir, "unet"))
449
+
450
+ # make sure to pop weight so that corresponding model is not saved again
451
+ weights.pop()
452
+
453
+ def load_model_hook(models, input_dir):
454
+ if args.use_ema:
455
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
456
+ self.ema_unet.load_state_dict(load_model.state_dict())
457
+ self.ema_unet.to(self.accelerator.device)
458
+ del load_model
459
+
460
+ for i in range(len(models)):
461
+ # pop models so that they are not loaded again
462
+ model = models.pop()
463
+
464
+ # load diffusers style into model
465
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
466
+ model.register_to_config(**load_model.config)
467
+
468
+ model.load_state_dict(load_model.state_dict())
469
+ del load_model
470
+
471
+ self.accelerator.register_save_state_pre_hook(save_model_hook)
472
+ self.accelerator.register_load_state_pre_hook(load_model_hook)
473
+
474
+ if args.gradient_checkpointing:
475
+ self.unet.enable_gradient_checkpointing()
476
+
477
+ # Enable TF32 for faster training on Ampere GPUs,
478
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
479
+ if args.allow_tf32:
480
+ torch.backends.cuda.matmul.allow_tf32 = True
481
+
482
+ if args.scale_lr:
483
+ args.learning_rate = (
484
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * self.accelerator.num_processes
485
+ )
486
+
487
+ # Initialize the optimizer
488
+ if args.use_8bit_adam:
489
+ try:
490
+ import bitsandbytes as bnb
491
+ except ImportError:
492
+ raise ImportError(
493
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
494
+ )
495
+
496
+ optimizer_cls = bnb.optim.AdamW8bit
497
+ else:
498
+ optimizer_cls = torch.optim.AdamW
499
+
500
+ self.optimizer = optimizer_cls(
501
+ self.unet.parameters(),
502
+ lr=args.learning_rate,
503
+ betas=(args.adam_beta1, args.adam_beta2),
504
+ weight_decay=args.adam_weight_decay,
505
+ eps=args.adam_epsilon,
506
+ )
507
+
508
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
509
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
510
+
511
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
512
+ # download the dataset.
513
+ if args.dataset_name is not None:
514
+ # Downloading and loading a dataset from the hub.
515
+ dataset = load_dataset(
516
+ args.dataset_name,
517
+ args.dataset_config_name,
518
+ cache_dir=args.cache_dir,
519
+ )
520
+ else:
521
+ data_files = {}
522
+ data_files["train"] = self.train_data_dir
523
+ dataset = load_dataset(
524
+ "json",
525
+ data_files=data_files,
526
+ cache_dir=args.cache_dir,
527
+ )
528
+ # See more about loading custom images at
529
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
530
+
531
+ # Preprocessing the datasets.
532
+ # We need to tokenize inputs and targets.
533
+ column_names = dataset["train"].column_names
534
+
535
+ # Get the column names for input/target.
536
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
537
+ if args.image_column is None:
538
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
539
+ else:
540
+ image_column = args.image_column
541
+ if image_column not in column_names:
542
+ raise ValueError(
543
+ f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
544
+ )
545
+ if args.caption_column is None:
546
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
547
+ else:
548
+ caption_column = args.caption_column
549
+ if caption_column not in column_names:
550
+ raise ValueError(
551
+ f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
552
+ )
553
+
554
+ # Preprocessing the datasets.
555
+ # We need to tokenize input captions and transform the images.
556
+ def tokenize_captions(examples, is_train=True):
557
+ captions = []
558
+ for caption in examples[caption_column]:
559
+ if isinstance(caption, str):
560
+ captions.append(caption)
561
+ elif isinstance(caption, (list, np.ndarray)):
562
+ # take a random caption if there are multiple
563
+ captions.append(random.choice(caption) if is_train else caption[0])
564
+ else:
565
+ raise ValueError(
566
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
567
+ )
568
+ inputs = tokenizer(
569
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
570
+ return_tensors="pt"
571
+ )
572
+ return inputs.input_ids
573
+
574
+ def preprocess_train(examples):
575
+ examples["input_ids"] = tokenize_captions(examples)
576
+ examples["rm_input_ids"] = self.reward_model.blip.tokenizer(examples[caption_column], padding='max_length',
577
+ truncation=True, max_length=35,
578
+ return_tensors="pt").input_ids
579
+ examples["rm_attention_mask"] = self.reward_model.blip.tokenizer(examples[caption_column],
580
+ padding='max_length', truncation=True,
581
+ max_length=35,
582
+ return_tensors="pt").attention_mask
583
+ return examples
584
+
585
+ with self.accelerator.main_process_first():
586
+ if args.max_train_samples is not None:
587
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
588
+ # Set the training transforms
589
+ self.train_dataset = dataset["train"].with_transform(preprocess_train)
590
+
591
+ def collate_fn(examples):
592
+ input_ids = torch.stack([example["input_ids"] for example in examples])
593
+ rm_input_ids = torch.stack([example["rm_input_ids"] for example in examples])
594
+ rm_attention_mask = torch.stack([example["rm_attention_mask"] for example in examples])
595
+ input_ids = input_ids.view(-1, input_ids.shape[-1])
596
+ rm_input_ids = rm_input_ids.view(-1, rm_input_ids.shape[-1])
597
+ rm_attention_mask = rm_attention_mask.view(-1, rm_attention_mask.shape[-1])
598
+ return {"input_ids": input_ids, "rm_input_ids": rm_input_ids, "rm_attention_mask": rm_attention_mask}
599
+
600
+ # DataLoaders creation:
601
+ self.train_dataloader = torch.utils.data.DataLoader(
602
+ self.train_dataset,
603
+ shuffle=True,
604
+ collate_fn=collate_fn,
605
+ batch_size=args.train_batch_size,
606
+ num_workers=args.dataloader_num_workers,
607
+ )
608
+
609
+ # Scheduler and math around the number of training steps.
610
+ overrode_max_train_steps = False
611
+ self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
612
+ if args.max_train_steps is None:
613
+ args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
614
+ overrode_max_train_steps = True
615
+
616
+ self.lr_scheduler = get_scheduler(
617
+ args.lr_scheduler,
618
+ optimizer=self.optimizer,
619
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
620
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
621
+ )
622
+
623
+ # Prepare everything with our `self.accelerator`.
624
+ self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
625
+ self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler
626
+ )
627
+
628
+ if args.use_ema:
629
+ self.ema_unet.to(self.accelerator.device)
630
+
631
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
632
+ # as these models are only used for inference, keeping weights in full precision is not required.
633
+ self.weight_dtype = torch.float32
634
+ if self.accelerator.mixed_precision == "fp16":
635
+ self.weight_dtype = torch.float16
636
+ elif self.accelerator.mixed_precision == "bf16":
637
+ self.weight_dtype = torch.bfloat16
638
+
639
+ # Move text_encode and vae to gpu and cast to self.weight_dtype
640
+ self.text_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
641
+ self.vae.to(self.accelerator.device, dtype=self.weight_dtype)
642
+ self.reward_model.to(self.accelerator.device, dtype=self.weight_dtype)
643
+
644
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
645
+ self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps)
646
+ if overrode_max_train_steps:
647
+ args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch
648
+ # Afterwards we recalculate our number of training epochs
649
+ args.num_train_epochs = math.ceil(args.max_train_steps / self.num_update_steps_per_epoch)
650
+
651
+ # We need to initialize the trackers we use, and also store our configuration.
652
+ # The trackers initializes automatically on the main process.
653
+ if self.accelerator.is_main_process:
654
+ tracker_config = dict(vars(args))
655
+ tracker_config.pop("validation_prompts")
656
+ self.accelerator.init_trackers(args.tracker_project_name, tracker_config)
657
+
658
+ def train(self, args):
659
+
660
+ # Train!
661
+ total_batch_size = args.train_batch_size * self.accelerator.num_processes * args.gradient_accumulation_steps
662
+
663
+ logger.info("***** Running training *****")
664
+ logger.info(f" Num examples = {len(self.train_dataset)}")
665
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
666
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
667
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
668
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
669
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
670
+ global_step = 0
671
+ first_epoch = 0
672
+
673
+ # Potentially load in the weights and states from a previous save
674
+ if args.resume_from_checkpoint:
675
+ if args.resume_from_checkpoint != "latest":
676
+ path = os.path.basename(args.resume_from_checkpoint)
677
+ else:
678
+ # Get the most recent checkpoint
679
+ dirs = os.listdir(args.output_dir)
680
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
681
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
682
+ path = dirs[-1] if len(dirs) > 0 else None
683
+
684
+ if path is None:
685
+ self.accelerator.print(
686
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
687
+ )
688
+ args.resume_from_checkpoint = None
689
+ else:
690
+ self.accelerator.print(f"Resuming from checkpoint {path}")
691
+ self.accelerator.load_state(os.path.join(args.output_dir, path))
692
+ global_step = int(path.split("-")[1])
693
+
694
+ resume_global_step = global_step * args.gradient_accumulation_steps
695
+ first_epoch = global_step // self.num_update_steps_per_epoch
696
+ resume_step = resume_global_step % (self.num_update_steps_per_epoch * args.gradient_accumulation_steps)
697
+
698
+ # Only show the progress bar once on each machine.
699
+ progress_bar = tqdm(range(global_step, args.max_train_steps),
700
+ disable=not self.accelerator.is_local_main_process)
701
+ progress_bar.set_description("Steps")
702
+
703
+ for epoch in range(first_epoch, args.num_train_epochs):
704
+ self.unet.train()
705
+ train_loss = 0.0
706
+ for step, batch in enumerate(self.train_dataloader):
707
+ # Skip steps until we reach the resumed step
708
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
709
+ if step % args.gradient_accumulation_steps == 0:
710
+ progress_bar.update(1)
711
+ continue
712
+
713
+ with self.accelerator.accumulate(self.unet):
714
+ encoder_hidden_states = self.text_encoder(batch["input_ids"])[0]
715
+ latents = torch.randn((args.train_batch_size, 4, 64, 64), device=self.accelerator.device)
716
+
717
+ self.noise_scheduler.set_timesteps(40, device=self.accelerator.device)
718
+ timesteps = self.noise_scheduler.timesteps
719
+
720
+ mid_timestep = random.randint(30, 39)
721
+
722
+ for i, t in enumerate(timesteps[:mid_timestep]):
723
+ with torch.no_grad():
724
+ latent_model_input = latents
725
+ latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
726
+ noise_pred = self.unet(
727
+ latent_model_input,
728
+ t,
729
+ encoder_hidden_states=encoder_hidden_states,
730
+ ).sample
731
+ latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
732
+
733
+ latent_model_input = latents
734
+ latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input,
735
+ timesteps[mid_timestep])
736
+ noise_pred = self.unet(
737
+ latent_model_input,
738
+ timesteps[mid_timestep],
739
+ encoder_hidden_states=encoder_hidden_states,
740
+ ).sample
741
+ pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[mid_timestep],
742
+ latents).pred_original_sample.to(self.weight_dtype)
743
+
744
+ pred_original_sample = 1 / self.vae.config.scaling_factor * pred_original_sample
745
+ image = self.vae.decode(pred_original_sample.to(self.weight_dtype)).sample
746
+ image = (image / 2 + 0.5).clamp(0, 1)
747
+
748
+ # image encode
749
+ def _transform():
750
+ return Compose([
751
+ Resize(224, interpolation=BICUBIC),
752
+ CenterCrop(224),
753
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
754
+ ])
755
+
756
+ rm_preprocess = _transform()
757
+ image = rm_preprocess(image).to(self.accelerator.device)
758
+
759
+ rewards = self.reward_model.score_gard(batch["rm_input_ids"], batch["rm_attention_mask"], image)
760
+ loss = F.relu(-rewards + 2)
761
+ loss = loss.mean() * args.grad_scale
762
+
763
+ # Gather the losses across all processes for logging (if we use distributed training).
764
+ avg_loss = self.accelerator.gather(loss.repeat(args.train_batch_size)).mean()
765
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
766
+
767
+ # Backpropagate
768
+ self.accelerator.backward(loss)
769
+ if self.accelerator.sync_gradients:
770
+ self.accelerator.clip_grad_norm_(self.unet.parameters(), args.max_grad_norm)
771
+ self.optimizer.step()
772
+ self.lr_scheduler.step()
773
+ self.optimizer.zero_grad()
774
+
775
+ # Checks if the self.accelerator has performed an optimization step behind the scenes
776
+ if self.accelerator.sync_gradients:
777
+ if args.use_ema:
778
+ self.ema_unet.step(self.unet.parameters())
779
+ progress_bar.update(1)
780
+ global_step += 1
781
+ self.accelerator.log({"train_loss": train_loss}, step=global_step)
782
+ train_loss = 0.0
783
+
784
+ if global_step % args.checkpointing_steps == 0:
785
+ if self.accelerator.is_main_process:
786
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
787
+ self.accelerator.save_state(save_path)
788
+ logger.info(f"Saved state to {save_path}")
789
+
790
+ logs = {"step_loss": loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]}
791
+ progress_bar.set_postfix(**logs)
792
+
793
+ if global_step >= args.max_train_steps:
794
+ break
795
+
796
+ if self.accelerator.is_main_process:
797
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
798
+ if args.use_ema:
799
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
800
+ self.ema_unet.store(self.unet.parameters())
801
+ self.ema_unet.copy_to(self.unet.parameters())
802
+ if args.use_ema:
803
+ # Switch back to the original UNet parameters.
804
+ self.ema_unet.restore(self.unet.parameters())
805
+
806
+ # Create the pipeline using the trained modules and save it.
807
+ self.accelerator.wait_for_everyone()
808
+ if self.accelerator.is_main_process:
809
+ self.unet = self.accelerator.unwrap_model(self.unet)
810
+ if args.use_ema:
811
+ self.ema_unet.copy_to(self.unet.parameters())
812
+
813
+ pipeline = StableDiffusionPipeline.from_pretrained(
814
+ self.pretrained_model_name_or_path,
815
+ text_encoder=self.text_encoder,
816
+ vae=self.vae,
817
+ unet=self.unet,
818
+ revision=args.revision,
819
+ )
820
+ pipeline.save_pretrained(args.output_dir)
821
+
822
+ if args.push_to_hub:
823
+ upload_folder(
824
+ repo_id=self.repo_id,
825
+ folder_path=args.output_dir,
826
+ commit_message="End of training",
827
+ ignore_patterns=["step_*", "epoch_*"],
828
+ )
829
+
830
+ self.accelerator.end_training()
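For orientation, a driver sketch for the Trainer defined above; the Stable Diffusion checkpoint ID and the JSON prompt file below are placeholder values, and the flags come from parse_args() in this file:

from ImageReward.ReFL import Trainer, parse_args

args = parse_args()                                    # --train_batch_size, --max_train_steps, ...
trainer = Trainer(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",   # placeholder base model
    train_data_dir="data/refl_prompts.json",                           # placeholder prompt dataset
    args=args,
)
trainer.train(args)   # denoise to a random mid timestep, decode, score with ImageReward, backprop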
ImageReward/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .utils import *
2
+ from .models import *
3
+ from .ReFL import *
ImageReward/models/AestheticScore.py ADDED
@@ -0,0 +1,95 @@
1
+ '''
2
+ @File : AestheticScore.py
3
+ @Time : 2023/02/12 14:54:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ @Description: AestheticScore.
7
+ * Based on improved-aesthetic-predictor code base
8
+ * https://github.com/christophschuhmann/improved-aesthetic-predictor
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
+ import clip
16
+
17
+
18
+ # if you changed the MLP architecture during training, change it also here:
19
+ class MLP(nn.Module):
20
+ def __init__(self, input_size):
21
+ super().__init__()
22
+ self.input_size = input_size
23
+ self.layers = nn.Sequential(
24
+ nn.Linear(self.input_size, 1024),
25
+ # nn.ReLU(),
26
+ nn.Dropout(0.2),
27
+ nn.Linear(1024, 128),
28
+ # nn.ReLU(),
29
+ nn.Dropout(0.2),
30
+ nn.Linear(128, 64),
31
+ # nn.ReLU(),
32
+ nn.Dropout(0.1),
33
+
34
+ nn.Linear(64, 16),
35
+ # nn.ReLU(),
36
+
37
+ nn.Linear(16, 1)
38
+ )
39
+
40
+ def forward(self, x):
41
+ return self.layers(x)
42
+
43
+
44
+ class AestheticScore(nn.Module):
45
+ def __init__(self, download_root, device='cpu'):
46
+ super().__init__()
47
+ self.device = device
48
+ self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
49
+ download_root=download_root)
50
+ self.mlp = MLP(768)
51
+
52
+ if device == "cpu":
53
+ self.clip_model.float()
54
+ else:
55
+ clip.model.convert_weights(
56
+ self.clip_model) # Actually this line is unnecessary since CLIP is already in float16 by default
57
+
58
+ # have clip.logit_scale require no grad.
59
+ self.clip_model.logit_scale.requires_grad_(False)
60
+
61
+ def score(self, prompt, image_path):
62
+
63
+ if (type(image_path).__name__ == 'list'):
64
+ _, rewards = self.inference_rank(prompt, image_path)
65
+ return rewards
66
+
67
+ # image encode
68
+ pil_image = Image.open(image_path)
69
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
70
+ image_features = F.normalize(self.clip_model.encode_image(image)).float()
71
+
72
+ # score
73
+ rewards = self.mlp(image_features)
74
+
75
+ return rewards.detach().cpu().numpy().item()
76
+
77
+ def inference_rank(self, prompt, generations_list):
78
+
79
+ img_set = []
80
+ for generations in generations_list:
81
+ # image encode
82
+ img_path = generations
83
+ pil_image = Image.open(img_path)
84
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
85
+ image_features = F.normalize(self.clip_model.encode_image(image))
86
+ img_set.append(image_features)
87
+
88
+ img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
89
+ rewards = self.mlp(img_features)
90
+ rewards = torch.squeeze(rewards)
91
+ _, rank = torch.sort(rewards, dim=0, descending=True)
92
+ _, indices = torch.sort(rank, dim=0)
93
+ indices = indices + 1
94
+
95
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
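A usage sketch for the aesthetic predictor; it assumes the CLIP ViT-L/14 weights can be cached under download_root and that the MLP head is loaded with pretrained weights elsewhere (handled by ImageReward/utils.py, whose contents are not shown in this view). The paths are placeholders:

from ImageReward.models.AestheticScore import AestheticScore

scorer = AestheticScore(download_root="~/.cache/ImageReward", device="cuda")   # placeholder cache dir
# In practice the MLP checkpoint is loaded into scorer.mlp before scoring.
value = scorer.score(prompt=None, image_path="render.png")   # the prompt is not used by the MLP head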
ImageReward/models/BLIP/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .blip_pretrain import *
ImageReward/models/BLIP/blip.py ADDED
@@ -0,0 +1,70 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+ import torch
9
+ import os
10
+ from urllib.parse import urlparse
11
+ from timm.models.hub import download_cached_file
12
+ from transformers import BertTokenizer
13
+ from .vit import VisionTransformer, interpolate_pos_embed
14
+
15
+
16
+ def init_tokenizer():
17
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
18
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
19
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
20
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
21
+ return tokenizer
22
+
23
+
24
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
25
+
26
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
27
+ if vit=='base':
28
+ vision_width = 768
29
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
30
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
31
+ drop_path_rate=0 or drop_path_rate
32
+ )
33
+ elif vit=='large':
34
+ vision_width = 1024
35
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
36
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
37
+ drop_path_rate=0.1 or drop_path_rate
38
+ )
39
+ return visual_encoder, vision_width
40
+
41
+
42
+ def is_url(url_or_filename):
43
+ parsed = urlparse(url_or_filename)
44
+ return parsed.scheme in ("http", "https")
45
+
46
+ def load_checkpoint(model,url_or_filename):
47
+ if is_url(url_or_filename):
48
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
49
+ checkpoint = torch.load(cached_file, map_location='cpu')
50
+ elif os.path.isfile(url_or_filename):
51
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
52
+ else:
53
+ raise RuntimeError('checkpoint url or path is invalid')
54
+
55
+ state_dict = checkpoint['model']
56
+
57
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
58
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
59
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
60
+ model.visual_encoder_m)
61
+ for key in model.state_dict().keys():
62
+ if key in state_dict.keys():
63
+ if state_dict[key].shape!=model.state_dict()[key].shape:
64
+ print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
65
+ del state_dict[key]
66
+
67
+ msg = model.load_state_dict(state_dict,strict=False)
68
+ print('load checkpoint from %s'%url_or_filename)
69
+ return model,msg
70
+
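For reference, a small sketch of how the two helpers above are combined (this mirrors what BLIP_Pretrain does in the next file); it assumes the timm and transformers dependencies are installed:

from ImageReward.models.BLIP.blip import create_vit, init_tokenizer

visual_encoder, vision_width = create_vit('large', image_size=224)   # ViT-L/16 backbone, vision_width == 1024
tokenizer = init_tokenizer()                                         # bert-base-uncased + [DEC]/[ENC] tokens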
ImageReward/models/BLIP/blip_pretrain.py ADDED
@@ -0,0 +1,43 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ from torch import nn
9
+ import os
10
+ from .med import BertConfig, BertModel
11
+ from .blip import create_vit, init_tokenizer
12
+
13
+ class BLIP_Pretrain(nn.Module):
14
+ def __init__(self,
15
+ med_config = "med_config.json",
16
+ image_size = 224,
17
+ vit = 'base',
18
+ vit_grad_ckpt = False,
19
+ vit_ckpt_layer = 0,
20
+ embed_dim = 256,
21
+ queue_size = 57600,
22
+ momentum = 0.995,
23
+ ):
24
+ """
25
+ Args:
26
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
27
+ image_size (int): input image size
28
+ vit (str): model size of vision transformer
29
+ """
30
+ super().__init__()
31
+
32
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
33
+
34
+ self.tokenizer = init_tokenizer()
35
+ encoder_config = BertConfig.from_json_file(med_config)
36
+ encoder_config.encoder_width = vision_width
37
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
38
+
39
+ text_width = self.text_encoder.config.hidden_size
40
+
41
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
42
+ self.text_proj = nn.Linear(text_width, embed_dim)
43
+
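A construction sketch for BLIP_Pretrain; "med_config.json" is a placeholder path to the mixture-of-encoder-decoder BERT config that the caller supplies (ImageReward.py passes it in explicitly):

import torch
from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain

blip = BLIP_Pretrain(med_config="med_config.json", image_size=224, vit='large')
dummy = torch.zeros(1, 3, 224, 224)                    # one blank 224x224 RGB image
image_embeds = blip.visual_encoder(dummy)              # (1, 197, 1024) for ViT-L/16 at 224 px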
ImageReward/models/BLIP/med.py ADDED
@@ -0,0 +1,947 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ * Based on huggingface code base
4
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
5
+ '''
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ from torch import Tensor, device, nn
12
+ import torch.utils.checkpoint
13
+ from torch import nn
14
+ from torch.nn import CrossEntropyLoss
15
+
16
+ from transformers.activations import ACT2FN
17
+ from transformers.file_utils import (
18
+ ModelOutput,
19
+ )
20
+ from transformers.modeling_outputs import (
21
+ BaseModelOutputWithPastAndCrossAttentions,
22
+ BaseModelOutputWithPoolingAndCrossAttentions,
23
+ CausalLMOutputWithCrossAttentions,
24
+ MaskedLMOutput,
25
+ MultipleChoiceModelOutput,
26
+ NextSentencePredictorOutput,
27
+ QuestionAnsweringModelOutput,
28
+ SequenceClassifierOutput,
29
+ TokenClassifierOutput,
30
+ )
31
+ from transformers.modeling_utils import (
32
+ PreTrainedModel,
33
+ apply_chunking_to_forward,
34
+ find_pruneable_heads_and_indices,
35
+ prune_linear_layer,
36
+ )
37
+ from transformers.utils import logging
38
+ from transformers.models.bert.configuration_bert import BertConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ class BertEmbeddings(nn.Module):
45
+ """Construct the embeddings from word and position embeddings."""
46
+
47
+ def __init__(self, config):
48
+ super().__init__()
49
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
50
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
51
+
52
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
53
+ # any TensorFlow checkpoint file
54
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
55
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
56
+
57
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
58
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
59
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
60
+
61
+ self.config = config
62
+
63
+ def forward(
64
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
65
+ ):
66
+ if input_ids is not None:
67
+ input_shape = input_ids.size()
68
+ else:
69
+ input_shape = inputs_embeds.size()[:-1]
70
+
71
+ seq_length = input_shape[1]
72
+
73
+ if position_ids is None:
74
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
75
+
76
+ if inputs_embeds is None:
77
+ inputs_embeds = self.word_embeddings(input_ids)
78
+
79
+ embeddings = inputs_embeds
80
+
81
+ if self.position_embedding_type == "absolute":
82
+ position_embeddings = self.position_embeddings(position_ids)
83
+ embeddings += position_embeddings
84
+ embeddings = self.LayerNorm(embeddings)
85
+ embeddings = self.dropout(embeddings)
86
+ return embeddings
87
+
88
+
89
+ class BertSelfAttention(nn.Module):
90
+ def __init__(self, config, is_cross_attention):
91
+ super().__init__()
92
+ self.config = config
93
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
94
+ raise ValueError(
95
+ "The hidden size (%d) is not a multiple of the number of attention "
96
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
97
+ )
98
+
99
+ self.num_attention_heads = config.num_attention_heads
100
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
101
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
102
+
103
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
104
+ if is_cross_attention:
105
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
106
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
107
+ else:
108
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
109
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
110
+
111
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
112
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
113
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
114
+ self.max_position_embeddings = config.max_position_embeddings
115
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
116
+ self.save_attention = False
117
+
118
+ def save_attn_gradients(self, attn_gradients):
119
+ self.attn_gradients = attn_gradients
120
+
121
+ def get_attn_gradients(self):
122
+ return self.attn_gradients
123
+
124
+ def save_attention_map(self, attention_map):
125
+ self.attention_map = attention_map
126
+
127
+ def get_attention_map(self):
128
+ return self.attention_map
129
+
130
+ def transpose_for_scores(self, x):
131
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
132
+ x = x.view(*new_x_shape)
133
+ return x.permute(0, 2, 1, 3)
134
+
135
+ def forward(
136
+ self,
137
+ hidden_states,
138
+ attention_mask=None,
139
+ head_mask=None,
140
+ encoder_hidden_states=None,
141
+ encoder_attention_mask=None,
142
+ past_key_value=None,
143
+ output_attentions=False,
144
+ ):
145
+ mixed_query_layer = self.query(hidden_states)
146
+
147
+ # If this is instantiated as a cross-attention module, the keys
148
+ # and values come from an encoder; the attention mask needs to be
149
+ # such that the encoder's padding tokens are not attended to.
150
+ is_cross_attention = encoder_hidden_states is not None
151
+
152
+ if is_cross_attention:
153
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
154
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
155
+ attention_mask = encoder_attention_mask
156
+ elif past_key_value is not None:
157
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
158
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
159
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
160
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
161
+ else:
162
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
163
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
164
+
165
+ query_layer = self.transpose_for_scores(mixed_query_layer)
166
+
167
+ past_key_value = (key_layer, value_layer)
168
+
169
+ # Take the dot product between "query" and "key" to get the raw attention scores.
170
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
171
+
172
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
173
+ seq_length = hidden_states.size()[1]
174
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
175
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
176
+ distance = position_ids_l - position_ids_r
177
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
178
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
179
+
180
+ if self.position_embedding_type == "relative_key":
181
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
182
+ attention_scores = attention_scores + relative_position_scores
183
+ elif self.position_embedding_type == "relative_key_query":
184
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
185
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
186
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
187
+
188
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
189
+ if attention_mask is not None:
190
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
191
+ attention_scores = attention_scores + attention_mask
192
+
193
+ # Normalize the attention scores to probabilities.
194
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
195
+
196
+ if is_cross_attention and self.save_attention:
197
+ self.save_attention_map(attention_probs)
198
+ attention_probs.register_hook(self.save_attn_gradients)
199
+
200
+ # This is actually dropping out entire tokens to attend to, which might
201
+ # seem a bit unusual, but is taken from the original Transformer paper.
202
+ attention_probs_dropped = self.dropout(attention_probs)
203
+
204
+ # Mask heads if we want to
205
+ if head_mask is not None:
206
+ attention_probs_dropped = attention_probs_dropped * head_mask
207
+
208
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
209
+
210
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
211
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
212
+ context_layer = context_layer.view(*new_context_layer_shape)
213
+
214
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
215
+
216
+ outputs = outputs + (past_key_value,)
217
+ return outputs
218
+
219
+
220
+ class BertSelfOutput(nn.Module):
221
+ def __init__(self, config):
222
+ super().__init__()
223
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
224
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
225
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
226
+
227
+ def forward(self, hidden_states, input_tensor):
228
+ hidden_states = self.dense(hidden_states)
229
+ hidden_states = self.dropout(hidden_states)
230
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
231
+ return hidden_states
232
+
233
+
234
+ class BertAttention(nn.Module):
235
+ def __init__(self, config, is_cross_attention=False):
236
+ super().__init__()
237
+ self.self = BertSelfAttention(config, is_cross_attention)
238
+ self.output = BertSelfOutput(config)
239
+ self.pruned_heads = set()
240
+
241
+ def prune_heads(self, heads):
242
+ if len(heads) == 0:
243
+ return
244
+ heads, index = find_pruneable_heads_and_indices(
245
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
246
+ )
247
+
248
+ # Prune linear layers
249
+ self.self.query = prune_linear_layer(self.self.query, index)
250
+ self.self.key = prune_linear_layer(self.self.key, index)
251
+ self.self.value = prune_linear_layer(self.self.value, index)
252
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
253
+
254
+ # Update hyper params and store pruned heads
255
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
256
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
257
+ self.pruned_heads = self.pruned_heads.union(heads)
258
+
259
+ def forward(
260
+ self,
261
+ hidden_states,
262
+ attention_mask=None,
263
+ head_mask=None,
264
+ encoder_hidden_states=None,
265
+ encoder_attention_mask=None,
266
+ past_key_value=None,
267
+ output_attentions=False,
268
+ ):
269
+ self_outputs = self.self(
270
+ hidden_states,
271
+ attention_mask,
272
+ head_mask,
273
+ encoder_hidden_states,
274
+ encoder_attention_mask,
275
+ past_key_value,
276
+ output_attentions,
277
+ )
278
+ attention_output = self.output(self_outputs[0], hidden_states)
279
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
280
+ return outputs
281
+
282
+
283
+ class BertIntermediate(nn.Module):
284
+ def __init__(self, config):
285
+ super().__init__()
286
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
287
+ if isinstance(config.hidden_act, str):
288
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
289
+ else:
290
+ self.intermediate_act_fn = config.hidden_act
291
+
292
+ def forward(self, hidden_states):
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.intermediate_act_fn(hidden_states)
295
+ return hidden_states
296
+
297
+
298
+ class BertOutput(nn.Module):
299
+ def __init__(self, config):
300
+ super().__init__()
301
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
302
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
303
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
304
+
305
+ def forward(self, hidden_states, input_tensor):
306
+ hidden_states = self.dense(hidden_states)
307
+ hidden_states = self.dropout(hidden_states)
308
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
309
+ return hidden_states
310
+
311
+
312
+ class BertLayer(nn.Module):
313
+ def __init__(self, config, layer_num):
314
+ super().__init__()
315
+ self.config = config
316
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
317
+ self.seq_len_dim = 1
318
+ self.attention = BertAttention(config)
319
+ self.layer_num = layer_num
320
+ if self.config.add_cross_attention:
321
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
322
+ self.intermediate = BertIntermediate(config)
323
+ self.output = BertOutput(config)
324
+
325
+ def forward(
326
+ self,
327
+ hidden_states,
328
+ attention_mask=None,
329
+ head_mask=None,
330
+ encoder_hidden_states=None,
331
+ encoder_attention_mask=None,
332
+ past_key_value=None,
333
+ output_attentions=False,
334
+ mode=None,
335
+ ):
336
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
337
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
338
+ self_attention_outputs = self.attention(
339
+ hidden_states,
340
+ attention_mask,
341
+ head_mask,
342
+ output_attentions=output_attentions,
343
+ past_key_value=self_attn_past_key_value,
344
+ )
345
+ attention_output = self_attention_outputs[0]
346
+
347
+ outputs = self_attention_outputs[1:-1]
348
+ present_key_value = self_attention_outputs[-1]
349
+
350
+ if mode=='multimodal':
351
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
352
+
353
+ cross_attention_outputs = self.crossattention(
354
+ attention_output,
355
+ attention_mask,
356
+ head_mask,
357
+ encoder_hidden_states,
358
+ encoder_attention_mask,
359
+ output_attentions=output_attentions,
360
+ )
361
+ attention_output = cross_attention_outputs[0]
362
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
363
+ layer_output = apply_chunking_to_forward(
364
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
365
+ )
366
+ outputs = (layer_output,) + outputs
367
+
368
+ outputs = outputs + (present_key_value,)
369
+
370
+ return outputs
371
+
372
+ def feed_forward_chunk(self, attention_output):
373
+ intermediate_output = self.intermediate(attention_output)
374
+ layer_output = self.output(intermediate_output, attention_output)
375
+ return layer_output
376
+
377
+
378
+ class BertEncoder(nn.Module):
379
+ def __init__(self, config):
380
+ super().__init__()
381
+ self.config = config
382
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
383
+ self.gradient_checkpointing = False
384
+
385
+ def forward(
386
+ self,
387
+ hidden_states,
388
+ attention_mask=None,
389
+ head_mask=None,
390
+ encoder_hidden_states=None,
391
+ encoder_attention_mask=None,
392
+ past_key_values=None,
393
+ use_cache=None,
394
+ output_attentions=False,
395
+ output_hidden_states=False,
396
+ return_dict=True,
397
+ mode='multimodal',
398
+ ):
399
+ all_hidden_states = () if output_hidden_states else None
400
+ all_self_attentions = () if output_attentions else None
401
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
402
+
403
+ next_decoder_cache = () if use_cache else None
404
+
405
+ for i in range(self.config.num_hidden_layers):
406
+ layer_module = self.layer[i]
407
+ if output_hidden_states:
408
+ all_hidden_states = all_hidden_states + (hidden_states,)
409
+
410
+ layer_head_mask = head_mask[i] if head_mask is not None else None
411
+ past_key_value = past_key_values[i] if past_key_values is not None else None
412
+
413
+ if self.gradient_checkpointing and self.training:
414
+
415
+ if use_cache:
416
+ logger.warn(
417
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
418
+ )
419
+ use_cache = False
420
+
421
+ def create_custom_forward(module):
422
+ def custom_forward(*inputs):
423
+ return module(*inputs, past_key_value, output_attentions)
424
+
425
+ return custom_forward
426
+
427
+ layer_outputs = torch.utils.checkpoint.checkpoint(
428
+ create_custom_forward(layer_module),
429
+ hidden_states,
430
+ attention_mask,
431
+ layer_head_mask,
432
+ encoder_hidden_states,
433
+ encoder_attention_mask,
434
+ mode=mode,
435
+ )
436
+ else:
437
+ layer_outputs = layer_module(
438
+ hidden_states,
439
+ attention_mask,
440
+ layer_head_mask,
441
+ encoder_hidden_states,
442
+ encoder_attention_mask,
443
+ past_key_value,
444
+ output_attentions,
445
+ mode=mode,
446
+ )
447
+
448
+ hidden_states = layer_outputs[0]
449
+ if use_cache:
450
+ next_decoder_cache += (layer_outputs[-1],)
451
+ if output_attentions:
452
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
453
+
454
+ if output_hidden_states:
455
+ all_hidden_states = all_hidden_states + (hidden_states,)
456
+
457
+ if not return_dict:
458
+ return tuple(
459
+ v
460
+ for v in [
461
+ hidden_states,
462
+ next_decoder_cache,
463
+ all_hidden_states,
464
+ all_self_attentions,
465
+ all_cross_attentions,
466
+ ]
467
+ if v is not None
468
+ )
469
+ return BaseModelOutputWithPastAndCrossAttentions(
470
+ last_hidden_state=hidden_states,
471
+ past_key_values=next_decoder_cache,
472
+ hidden_states=all_hidden_states,
473
+ attentions=all_self_attentions,
474
+ cross_attentions=all_cross_attentions,
475
+ )
476
+
477
+
478
+ class BertPooler(nn.Module):
479
+ def __init__(self, config):
480
+ super().__init__()
481
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
482
+ self.activation = nn.Tanh()
483
+
484
+ def forward(self, hidden_states):
485
+ # We "pool" the model by simply taking the hidden state corresponding
486
+ # to the first token.
487
+ first_token_tensor = hidden_states[:, 0]
488
+ pooled_output = self.dense(first_token_tensor)
489
+ pooled_output = self.activation(pooled_output)
490
+ return pooled_output
491
+
492
+
493
+ class BertPredictionHeadTransform(nn.Module):
494
+ def __init__(self, config):
495
+ super().__init__()
496
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
497
+ if isinstance(config.hidden_act, str):
498
+ self.transform_act_fn = ACT2FN[config.hidden_act]
499
+ else:
500
+ self.transform_act_fn = config.hidden_act
501
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
502
+
503
+ def forward(self, hidden_states):
504
+ hidden_states = self.dense(hidden_states)
505
+ hidden_states = self.transform_act_fn(hidden_states)
506
+ hidden_states = self.LayerNorm(hidden_states)
507
+ return hidden_states
508
+
509
+
510
+ class BertLMPredictionHead(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.transform = BertPredictionHeadTransform(config)
514
+
515
+ # The output weights are the same as the input embeddings, but there is
516
+ # an output-only bias for each token.
517
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
518
+
519
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
520
+
521
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
522
+ self.decoder.bias = self.bias
523
+
524
+ def forward(self, hidden_states):
525
+ hidden_states = self.transform(hidden_states)
526
+ hidden_states = self.decoder(hidden_states)
527
+ return hidden_states
528
+
529
+
530
+ class BertOnlyMLMHead(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.predictions = BertLMPredictionHead(config)
534
+
535
+ def forward(self, sequence_output):
536
+ prediction_scores = self.predictions(sequence_output)
537
+ return prediction_scores
538
+
539
+
540
+ class BertPreTrainedModel(PreTrainedModel):
541
+ """
542
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
543
+ models.
544
+ """
545
+
546
+ config_class = BertConfig
547
+ base_model_prefix = "bert"
548
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
549
+
550
+ def _init_weights(self, module):
551
+ """ Initialize the weights """
552
+ if isinstance(module, (nn.Linear, nn.Embedding)):
553
+ # Slightly different from the TF version which uses truncated_normal for initialization
554
+ # cf https://github.com/pytorch/pytorch/pull/5617
555
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
556
+ elif isinstance(module, nn.LayerNorm):
557
+ module.bias.data.zero_()
558
+ module.weight.data.fill_(1.0)
559
+ if isinstance(module, nn.Linear) and module.bias is not None:
560
+ module.bias.data.zero_()
561
+
562
+
563
+ class BertModel(BertPreTrainedModel):
564
+ """
565
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
566
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
567
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
568
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
569
+ To behave as a decoder, the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
570
+ input to the forward pass.
571
+ """
572
+
573
+ def __init__(self, config, add_pooling_layer=True):
574
+ super().__init__(config)
575
+ self.config = config
576
+
577
+ self.embeddings = BertEmbeddings(config)
578
+
579
+ self.encoder = BertEncoder(config)
580
+
581
+ self.pooler = BertPooler(config) if add_pooling_layer else None
582
+
583
+ self.init_weights()
584
+
585
+
586
+ def get_input_embeddings(self):
587
+ return self.embeddings.word_embeddings
588
+
589
+ def set_input_embeddings(self, value):
590
+ self.embeddings.word_embeddings = value
591
+
592
+ def _prune_heads(self, heads_to_prune):
593
+ """
594
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
595
+ class PreTrainedModel
596
+ """
597
+ for layer, heads in heads_to_prune.items():
598
+ self.encoder.layer[layer].attention.prune_heads(heads)
599
+
600
+
601
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
602
+ """
603
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
604
+
605
+ Arguments:
606
+ attention_mask (:obj:`torch.Tensor`):
607
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
608
+ input_shape (:obj:`Tuple[int]`):
609
+ The shape of the input to the model.
610
+ device: (:obj:`torch.device`):
611
+ The device of the input to the model.
612
+
613
+ Returns:
614
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
615
+ """
616
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
617
+ # ourselves in which case we just need to make it broadcastable to all heads.
618
+ if attention_mask.dim() == 3:
619
+ extended_attention_mask = attention_mask[:, None, :, :]
620
+ elif attention_mask.dim() == 2:
621
+ # Provided a padding mask of dimensions [batch_size, seq_length]
622
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
623
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
624
+ if is_decoder:
625
+ batch_size, seq_length = input_shape
626
+
627
+ seq_ids = torch.arange(seq_length, device=device)
628
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
629
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
630
+ # causal and attention masks must have same type with pytorch version < 1.3
631
+ causal_mask = causal_mask.to(attention_mask.dtype)
632
+
633
+ if causal_mask.shape[1] < attention_mask.shape[1]:
634
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
635
+ causal_mask = torch.cat(
636
+ [
637
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
638
+ causal_mask,
639
+ ],
640
+ axis=-1,
641
+ )
642
+
643
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
644
+ else:
645
+ extended_attention_mask = attention_mask[:, None, None, :]
646
+ else:
647
+ raise ValueError(
648
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
649
+ input_shape, attention_mask.shape
650
+ )
651
+ )
652
+
653
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
654
+ # masked positions, this operation will create a tensor which is 0.0 for
655
+ # positions we want to attend and -10000.0 for masked positions.
656
+ # Since we are adding it to the raw scores before the softmax, this is
657
+ # effectively the same as removing these entirely.
658
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
659
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
660
+ return extended_attention_mask
661
+
662
+ def forward(
663
+ self,
664
+ input_ids=None,
665
+ attention_mask=None,
666
+ position_ids=None,
667
+ head_mask=None,
668
+ inputs_embeds=None,
669
+ encoder_embeds=None,
670
+ encoder_hidden_states=None,
671
+ encoder_attention_mask=None,
672
+ past_key_values=None,
673
+ use_cache=None,
674
+ output_attentions=None,
675
+ output_hidden_states=None,
676
+ return_dict=None,
677
+ is_decoder=False,
678
+ mode='multimodal',
679
+ ):
680
+ r"""
681
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
682
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
683
+ the model is configured as a decoder.
684
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
685
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
686
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
687
+ - 1 for tokens that are **not masked**,
688
+ - 0 for tokens that are **masked**.
689
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
690
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
691
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
692
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
693
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
694
+ use_cache (:obj:`bool`, `optional`):
695
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
696
+ decoding (see :obj:`past_key_values`).
697
+ """
698
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
699
+ output_hidden_states = (
700
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
701
+ )
702
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
703
+
704
+ if is_decoder:
705
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
706
+ else:
707
+ use_cache = False
708
+
709
+ if input_ids is not None and inputs_embeds is not None:
710
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
711
+ elif input_ids is not None:
712
+ input_shape = input_ids.size()
713
+ batch_size, seq_length = input_shape
714
+ device = input_ids.device
715
+ elif inputs_embeds is not None:
716
+ input_shape = inputs_embeds.size()[:-1]
717
+ batch_size, seq_length = input_shape
718
+ device = inputs_embeds.device
719
+ elif encoder_embeds is not None:
720
+ input_shape = encoder_embeds.size()[:-1]
721
+ batch_size, seq_length = input_shape
722
+ device = encoder_embeds.device
723
+ else:
724
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
725
+
726
+ # past_key_values_length
727
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
728
+
729
+ if attention_mask is None:
730
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
731
+
732
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
733
+ # ourselves in which case we just need to make it broadcastable to all heads.
734
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
735
+ device, is_decoder)
736
+
737
+ # If a 2D or 3D attention mask is provided for the cross-attention
738
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
739
+ if encoder_hidden_states is not None:
740
+ if type(encoder_hidden_states) == list:
741
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
742
+ else:
743
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
744
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
745
+
746
+ if type(encoder_attention_mask) == list:
747
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
748
+ elif encoder_attention_mask is None:
749
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
750
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
751
+ else:
752
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
753
+ else:
754
+ encoder_extended_attention_mask = None
755
+
756
+ # Prepare head mask if needed
757
+ # 1.0 in head_mask indicate we keep the head
758
+ # attention_probs has shape bsz x n_heads x N x N
759
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
760
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
761
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
762
+
763
+ if encoder_embeds is None:
764
+ embedding_output = self.embeddings(
765
+ input_ids=input_ids,
766
+ position_ids=position_ids,
767
+ inputs_embeds=inputs_embeds,
768
+ past_key_values_length=past_key_values_length,
769
+ )
770
+ else:
771
+ embedding_output = encoder_embeds
772
+
773
+ encoder_outputs = self.encoder(
774
+ embedding_output,
775
+ attention_mask=extended_attention_mask,
776
+ head_mask=head_mask,
777
+ encoder_hidden_states=encoder_hidden_states,
778
+ encoder_attention_mask=encoder_extended_attention_mask,
779
+ past_key_values=past_key_values,
780
+ use_cache=use_cache,
781
+ output_attentions=output_attentions,
782
+ output_hidden_states=output_hidden_states,
783
+ return_dict=return_dict,
784
+ mode=mode,
785
+ )
786
+ sequence_output = encoder_outputs[0]
787
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
788
+
789
+ if not return_dict:
790
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
791
+
792
+ return BaseModelOutputWithPoolingAndCrossAttentions(
793
+ last_hidden_state=sequence_output,
794
+ pooler_output=pooled_output,
795
+ past_key_values=encoder_outputs.past_key_values,
796
+ hidden_states=encoder_outputs.hidden_states,
797
+ attentions=encoder_outputs.attentions,
798
+ cross_attentions=encoder_outputs.cross_attentions,
799
+ )
800
+
801
+
802
+
803
+ class BertLMHeadModel(BertPreTrainedModel):
804
+
805
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
806
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
807
+
808
+ def __init__(self, config):
809
+ super().__init__(config)
810
+
811
+ self.bert = BertModel(config, add_pooling_layer=False)
812
+ self.cls = BertOnlyMLMHead(config)
813
+
814
+ self.init_weights()
815
+
816
+ def get_output_embeddings(self):
817
+ return self.cls.predictions.decoder
818
+
819
+ def set_output_embeddings(self, new_embeddings):
820
+ self.cls.predictions.decoder = new_embeddings
821
+
822
+ def forward(
823
+ self,
824
+ input_ids=None,
825
+ attention_mask=None,
826
+ position_ids=None,
827
+ head_mask=None,
828
+ inputs_embeds=None,
829
+ encoder_hidden_states=None,
830
+ encoder_attention_mask=None,
831
+ labels=None,
832
+ past_key_values=None,
833
+ use_cache=None,
834
+ output_attentions=None,
835
+ output_hidden_states=None,
836
+ return_dict=None,
837
+ return_logits=False,
838
+ is_decoder=True,
839
+ reduction='mean',
840
+ mode='multimodal',
841
+ ):
842
+ r"""
843
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
844
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
845
+ the model is configured as a decoder.
846
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
847
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
848
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
849
+ - 1 for tokens that are **not masked**,
850
+ - 0 for tokens that are **masked**.
851
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
852
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
853
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
854
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
855
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
856
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
857
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
858
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
859
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
860
+ use_cache (:obj:`bool`, `optional`):
861
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
862
+ decoding (see :obj:`past_key_values`).
863
+ Returns:
864
+ Example::
865
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
866
+ >>> import torch
867
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
868
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
869
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
870
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
871
+ >>> outputs = model(**inputs)
872
+ >>> prediction_logits = outputs.logits
873
+ """
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+ if labels is not None:
876
+ use_cache = False
877
+
878
+ outputs = self.bert(
879
+ input_ids,
880
+ attention_mask=attention_mask,
881
+ position_ids=position_ids,
882
+ head_mask=head_mask,
883
+ inputs_embeds=inputs_embeds,
884
+ encoder_hidden_states=encoder_hidden_states,
885
+ encoder_attention_mask=encoder_attention_mask,
886
+ past_key_values=past_key_values,
887
+ use_cache=use_cache,
888
+ output_attentions=output_attentions,
889
+ output_hidden_states=output_hidden_states,
890
+ return_dict=return_dict,
891
+ is_decoder=is_decoder,
892
+ mode=mode,
893
+ )
894
+
895
+ sequence_output = outputs[0]
896
+ prediction_scores = self.cls(sequence_output)
897
+
898
+ if return_logits:
899
+ return prediction_scores[:, :-1, :].contiguous()
900
+
901
+ lm_loss = None
902
+ if labels is not None:
903
+ # we are doing next-token prediction; shift prediction scores and input ids by one
904
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
905
+ labels = labels[:, 1:].contiguous()
906
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
907
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
908
+ if reduction=='none':
909
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
910
+
911
+ if not return_dict:
912
+ output = (prediction_scores,) + outputs[2:]
913
+ return ((lm_loss,) + output) if lm_loss is not None else output
914
+
915
+ return CausalLMOutputWithCrossAttentions(
916
+ loss=lm_loss,
917
+ logits=prediction_scores,
918
+ past_key_values=outputs.past_key_values,
919
+ hidden_states=outputs.hidden_states,
920
+ attentions=outputs.attentions,
921
+ cross_attentions=outputs.cross_attentions,
922
+ )
923
+
924
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
925
+ input_shape = input_ids.shape
926
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
927
+ if attention_mask is None:
928
+ attention_mask = input_ids.new_ones(input_shape)
929
+
930
+ # cut decoder_input_ids if past is used
931
+ if past is not None:
932
+ input_ids = input_ids[:, -1:]
933
+
934
+ return {
935
+ "input_ids": input_ids,
936
+ "attention_mask": attention_mask,
937
+ "past_key_values": past,
938
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
939
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
940
+ "is_decoder": True,
941
+ }
942
+
943
+ def _reorder_cache(self, past, beam_idx):
944
+ reordered_past = ()
945
+ for layer_past in past:
946
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
947
+ return reordered_past
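
A minimal sketch of driving the BertModel above as BLIP's multimodal text encoder: image tokens enter through the cross-attention layers that the config enables. The hyper-parameter values and the random image features below are illustrative assumptions; the real settings come from BLIP's med_config.json and a ViT image encoder.

import torch
from transformers import BertTokenizer
from ImageReward.models.BLIP.med import BertConfig, BertModel

# Illustrative config; the shipped med_config.json defines the real values.
config = BertConfig(vocab_size=30524, hidden_size=768, num_hidden_layers=12,
                    num_attention_heads=12, intermediate_size=3072)
config.encoder_width = 1024        # width of the visual features used as cross-attention keys/values
config.add_cross_attention = True  # adds a cross-attention sub-layer to every BertLayer

text_encoder = BertModel(config, add_pooling_layer=False).eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

tokens = tokenizer('a photo of the Sydney Opera House', return_tensors='pt')
image_embeds = torch.randn(1, 197, 1024)  # stand-in for ViT-L/16 tokens of one 224x224 image
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)

out = text_encoder(tokens.input_ids,
                   attention_mask=tokens.attention_mask,
                   encoder_hidden_states=image_embeds,
                   encoder_attention_mask=image_atts,
                   mode='multimodal')    # mode='text' skips the cross-attention path entirely
print(out.last_hidden_state.shape)       # torch.Size([1, seq_len, 768])
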
ImageReward/models/BLIP/vit.py ADDED
@@ -0,0 +1,301 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ * Based on timm code base
4
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
5
+ '''
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from functools import partial
11
+
12
+ from timm.models.vision_transformer import _cfg, PatchEmbed
13
+ from timm.models.registry import register_model
14
+ from timm.models.layers import trunc_normal_, DropPath
15
+ from timm.models.helpers import named_apply, adapt_input_conv
16
+
17
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
18
+
19
+ class Mlp(nn.Module):
20
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
21
+ """
22
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.fc1 = nn.Linear(in_features, hidden_features)
27
+ self.act = act_layer()
28
+ self.fc2 = nn.Linear(hidden_features, out_features)
29
+ self.drop = nn.Dropout(drop)
30
+
31
+ def forward(self, x):
32
+ x = self.fc1(x)
33
+ x = self.act(x)
34
+ x = self.drop(x)
35
+ x = self.fc2(x)
36
+ x = self.drop(x)
37
+ return x
38
+
39
+
40
+ class Attention(nn.Module):
41
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
42
+ super().__init__()
43
+ self.num_heads = num_heads
44
+ head_dim = dim // num_heads
45
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
46
+ self.scale = qk_scale or head_dim ** -0.5
47
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
48
+ self.attn_drop = nn.Dropout(attn_drop)
49
+ self.proj = nn.Linear(dim, dim)
50
+ self.proj_drop = nn.Dropout(proj_drop)
51
+ self.attn_gradients = None
52
+ self.attention_map = None
53
+
54
+ def save_attn_gradients(self, attn_gradients):
55
+ self.attn_gradients = attn_gradients
56
+
57
+ def get_attn_gradients(self):
58
+ return self.attn_gradients
59
+
60
+ def save_attention_map(self, attention_map):
61
+ self.attention_map = attention_map
62
+
63
+ def get_attention_map(self):
64
+ return self.attention_map
65
+
66
+ def forward(self, x, register_hook=False):
67
+ B, N, C = x.shape
68
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
69
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
70
+
71
+ attn = (q @ k.transpose(-2, -1)) * self.scale
72
+ attn = attn.softmax(dim=-1)
73
+ attn = self.attn_drop(attn)
74
+
75
+ if register_hook:
76
+ self.save_attention_map(attn)
77
+ attn.register_hook(self.save_attn_gradients)
78
+
79
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
80
+ x = self.proj(x)
81
+ x = self.proj_drop(x)
82
+ return x
83
+
84
+
85
+ class Block(nn.Module):
86
+
87
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
88
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
89
+ super().__init__()
90
+ self.norm1 = norm_layer(dim)
91
+ self.attn = Attention(
92
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
93
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
94
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
95
+ self.norm2 = norm_layer(dim)
96
+ mlp_hidden_dim = int(dim * mlp_ratio)
97
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
98
+
99
+ if use_grad_checkpointing:
100
+ self.attn = checkpoint_wrapper(self.attn)
101
+ self.mlp = checkpoint_wrapper(self.mlp)
102
+
103
+ def forward(self, x, register_hook=False):
104
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
105
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
106
+ return x
107
+
108
+
109
+ class VisionTransformer(nn.Module):
110
+ """ Vision Transformer
111
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
112
+ https://arxiv.org/abs/2010.11929
113
+ """
114
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
115
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
116
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
117
+ use_grad_checkpointing=False, ckpt_layer=0):
118
+ """
119
+ Args:
120
+ img_size (int, tuple): input image size
121
+ patch_size (int, tuple): patch size
122
+ in_chans (int): number of input channels
123
+ num_classes (int): number of classes for classification head
124
+ embed_dim (int): embedding dimension
125
+ depth (int): depth of transformer
126
+ num_heads (int): number of attention heads
127
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
128
+ qkv_bias (bool): enable bias for qkv if True
129
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
130
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
131
+ drop_rate (float): dropout rate
132
+ attn_drop_rate (float): attention dropout rate
133
+ drop_path_rate (float): stochastic depth rate
134
+ norm_layer: (nn.Module): normalization layer
135
+ """
136
+ super().__init__()
137
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
138
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
139
+
140
+ self.patch_embed = PatchEmbed(
141
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
142
+
143
+ num_patches = self.patch_embed.num_patches
144
+
145
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
146
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
147
+ self.pos_drop = nn.Dropout(p=drop_rate)
148
+
149
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
150
+ self.blocks = nn.ModuleList([
151
+ Block(
152
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
153
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
154
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
155
+ )
156
+ for i in range(depth)])
157
+ self.norm = norm_layer(embed_dim)
158
+
159
+ trunc_normal_(self.pos_embed, std=.02)
160
+ trunc_normal_(self.cls_token, std=.02)
161
+ self.apply(self._init_weights)
162
+
163
+ def _init_weights(self, m):
164
+ if isinstance(m, nn.Linear):
165
+ trunc_normal_(m.weight, std=.02)
166
+ if isinstance(m, nn.Linear) and m.bias is not None:
167
+ nn.init.constant_(m.bias, 0)
168
+ elif isinstance(m, nn.LayerNorm):
169
+ nn.init.constant_(m.bias, 0)
170
+ nn.init.constant_(m.weight, 1.0)
171
+
172
+ @torch.jit.ignore
173
+ def no_weight_decay(self):
174
+ return {'pos_embed', 'cls_token'}
175
+
176
+ def forward(self, x, register_blk=-1):
177
+ B = x.shape[0]
178
+ x = self.patch_embed(x)
179
+
180
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
181
+ x = torch.cat((cls_tokens, x), dim=1)
182
+
183
+ x = x + self.pos_embed[:,:x.size(1),:]
184
+ x = self.pos_drop(x)
185
+
186
+ for i,blk in enumerate(self.blocks):
187
+ x = blk(x, register_blk==i)
188
+ x = self.norm(x)
189
+
190
+ return x
191
+
192
+ @torch.jit.ignore()
193
+ def load_pretrained(self, checkpoint_path, prefix=''):
194
+ _load_weights(self, checkpoint_path, prefix)
195
+
196
+
197
+ @torch.no_grad()
198
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
199
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
200
+ """
201
+ import numpy as np
202
+
203
+ def _n2p(w, t=True):
204
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
205
+ w = w.flatten()
206
+ if t:
207
+ if w.ndim == 4:
208
+ w = w.transpose([3, 2, 0, 1])
209
+ elif w.ndim == 3:
210
+ w = w.transpose([2, 0, 1])
211
+ elif w.ndim == 2:
212
+ w = w.transpose([1, 0])
213
+ return torch.from_numpy(w)
214
+
215
+ w = np.load(checkpoint_path)
216
+ if not prefix and 'opt/target/embedding/kernel' in w:
217
+ prefix = 'opt/target/'
218
+
219
+ if hasattr(model.patch_embed, 'backbone'):
220
+ # hybrid
221
+ backbone = model.patch_embed.backbone
222
+ stem_only = not hasattr(backbone, 'stem')
223
+ stem = backbone if stem_only else backbone.stem
224
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
225
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
226
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
227
+ if not stem_only:
228
+ for i, stage in enumerate(backbone.stages):
229
+ for j, block in enumerate(stage.blocks):
230
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
231
+ for r in range(3):
232
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
233
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
234
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
235
+ if block.downsample is not None:
236
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
237
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
238
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
239
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
240
+ else:
241
+ embed_conv_w = adapt_input_conv(
242
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
243
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
244
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
245
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
246
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
247
+ if pos_embed_w.shape != model.pos_embed.shape:
248
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
249
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
250
+ model.pos_embed.copy_(pos_embed_w)
251
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
252
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
253
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
254
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
255
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
256
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
257
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
258
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
259
+ for i, block in enumerate(model.blocks.children()):
260
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
261
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
262
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
263
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
264
+ block.attn.qkv.weight.copy_(torch.cat([
265
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
266
+ block.attn.qkv.bias.copy_(torch.cat([
267
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
268
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
269
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
270
+ for r in range(2):
271
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
272
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
273
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
274
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
275
+
276
+
277
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
278
+ # interpolate position embedding
279
+ embedding_size = pos_embed_checkpoint.shape[-1]
280
+ num_patches = visual_encoder.patch_embed.num_patches
281
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
282
+ # height (== width) for the checkpoint position embedding
283
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
284
+ # height (== width) for the new position embedding
285
+ new_size = int(num_patches ** 0.5)
286
+
287
+ if orig_size!=new_size:
288
+ # class_token and dist_token are kept unchanged
289
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
290
+ # only the position tokens are interpolated
291
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
292
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
293
+ pos_tokens = torch.nn.functional.interpolate(
294
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
295
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
296
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
297
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
298
+
299
+ return new_pos_embed
300
+ else:
301
+ return pos_embed_checkpoint
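
For reference, a quick sketch of instantiating the VisionTransformer above on its own; the ViT-L/16 settings mirror the 'large' branch used by BLIP_Pretrain, but treat the exact numbers here as assumptions.

from functools import partial

import torch
import torch.nn as nn

from ImageReward.models.BLIP.vit import VisionTransformer

# ViT-L/16 at 224x224 resolution (assumed values)
vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=1024, depth=24,
                        num_heads=16, mlp_ratio=4, qkv_bias=True,
                        norm_layer=partial(nn.LayerNorm, eps=1e-6)).eval()

x = torch.randn(1, 3, 224, 224)   # one already-normalized RGB image
with torch.no_grad():
    tokens = vit(x)               # [1, 197, 1024]: CLS token followed by 14x14 patch tokens
print(tokens.shape)
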
ImageReward/models/BLIPScore.py ADDED
@@ -0,0 +1,97 @@
1
+ '''
2
+ @File : BLIPScore.py
3
+ @Time : 2023/02/19 20:48:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ @Description: BLIPScore.
7
+ * Based on BLIP code base
8
+ * https://github.com/salesforce/BLIP
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
+ from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain
16
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
17
+
18
+ try:
19
+ from torchvision.transforms import InterpolationMode
20
+ BICUBIC = InterpolationMode.BICUBIC
21
+ except ImportError:
22
+ BICUBIC = Image.BICUBIC
23
+
24
+
25
+ def _convert_image_to_rgb(image):
26
+ return image.convert("RGB")
27
+
28
+
29
+ def _transform(n_px):
30
+ return Compose([
31
+ Resize(n_px, interpolation=BICUBIC),
32
+ CenterCrop(n_px),
33
+ _convert_image_to_rgb,
34
+ ToTensor(),
35
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
36
+ ])
37
+
38
+
39
+ class BLIPScore(nn.Module):
40
+ def __init__(self, med_config, device='cpu'):
41
+ super().__init__()
42
+ self.device = device
43
+
44
+ self.preprocess = _transform(224)
45
+ self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
46
+
47
+
48
+ def score(self, prompt, image_path):
49
+
50
+ if (type(image_path).__name__=='list'):
51
+ _, rewards = self.inference_rank(prompt, image_path)
52
+ return rewards
53
+
54
+ # text encode
55
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
56
+ text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
57
+ txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:]))
58
+
59
+ # image encode
60
+ pil_image = Image.open(image_path)
61
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
62
+ image_embeds = self.blip.visual_encoder(image)
63
+ image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1)
64
+
65
+ # score
66
+ rewards = torch.sum(torch.mul(txt_feature, image_features), dim=1, keepdim=True)
67
+
68
+ return rewards.detach().cpu().numpy().item()
69
+
70
+
71
+ def inference_rank(self, prompt, generations_list):
72
+
73
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
74
+ text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
75
+ txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:]))
76
+
77
+ txt_set = []
78
+ img_set = []
79
+ for generations in generations_list:
80
+ # image encode
81
+ img_path = generations
82
+ pil_image = Image.open(img_path)
83
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
84
+ image_embeds = self.blip.visual_encoder(image)
85
+ image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1)
86
+ img_set.append(image_features)
87
+ txt_set.append(txt_feature)
88
+
89
+ txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
90
+ img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
91
+ rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
92
+ rewards = torch.squeeze(rewards)
93
+ _, rank = torch.sort(rewards, dim=0, descending=True)
94
+ _, indices = torch.sort(rank, dim=0)
95
+ indices = indices + 1
96
+
97
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
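
BLIPScore.score returns a single text-image similarity for one file, while inference_rank ranks a list of candidates; a usage sketch follows. The config path and image paths are placeholders, and the BLIP weights stay randomly initialized unless a checkpoint is loaded first.

import torch
from ImageReward.models.BLIPScore import BLIPScore

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BLIPScore(med_config='checkpoints/med_config.json', device=device).to(device)
# NOTE: load a trained state_dict into model.blip before trusting the scores.

prompt = 'a photo of the Sydney Opera House'
reward = model.score(prompt, 'sample.png')               # float: cosine similarity of BLIP features
ranking, rewards = model.inference_rank(prompt, ['cand_0.png', 'cand_1.png'])
print(reward, ranking, rewards)
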
ImageReward/models/CLIPScore.py ADDED
@@ -0,0 +1,78 @@
1
+ '''
2
+ @File : CLIPScore.py
3
+ @Time : 2023/02/12 13:14:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ @Description: CLIPScore.
7
+ * Based on CLIP code base
8
+ * https://github.com/openai/CLIP
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
+ import clip
16
+
17
+ class CLIPScore(nn.Module):
18
+ def __init__(self, download_root, device='cpu'):
19
+ super().__init__()
20
+ self.device = device
21
+ self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
22
+ download_root=download_root)
23
+
24
+ if device == "cpu":
25
+ self.clip_model.float()
26
+ else:
27
+ clip.model.convert_weights(self.clip_model) # Actually this line is unnecessary since CLIP weights are already in float16 by default
28
+
29
+ # have clip.logit_scale require no grad.
30
+ self.clip_model.logit_scale.requires_grad_(False)
31
+
32
+
33
+ def score(self, prompt, image_path):
34
+
35
+ if (type(image_path).__name__=='list'):
36
+ _, rewards = self.inference_rank(prompt, image_path)
37
+ return rewards
38
+
39
+ # text encode
40
+ text = clip.tokenize(prompt, truncate=True).to(self.device)
41
+ txt_features = F.normalize(self.clip_model.encode_text(text))
42
+
43
+ # image encode
44
+ pil_image = Image.open(image_path)
45
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
46
+ image_features = F.normalize(self.clip_model.encode_image(image))
47
+
48
+ # score
49
+ rewards = torch.sum(torch.mul(txt_features, image_features), dim=1, keepdim=True)
50
+
51
+ return rewards.detach().cpu().numpy().item()
52
+
53
+
54
+ def inference_rank(self, prompt, generations_list):
55
+
56
+ text = clip.tokenize(prompt, truncate=True).to(self.device)
57
+ txt_feature = F.normalize(self.clip_model.encode_text(text))
58
+
59
+ txt_set = []
60
+ img_set = []
61
+ for generations in generations_list:
62
+ # image encode
63
+ img_path = generations
64
+ pil_image = Image.open(img_path)
65
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
66
+ image_features = F.normalize(self.clip_model.encode_image(image))
67
+ img_set.append(image_features)
68
+ txt_set.append(txt_feature)
69
+
70
+ txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim]
71
+ img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim]
72
+ rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True)
73
+ rewards = torch.squeeze(rewards)
74
+ _, rank = torch.sort(rewards, dim=0, descending=True)
75
+ _, indices = torch.sort(rank, dim=0)
76
+ indices = indices + 1
77
+
78
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
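For context, a minimal usage sketch of the `CLIPScore` class defined above; the prompt and image paths are placeholders, and the ViT-L/14 weights are fetched into `download_root` on first use.

```python
import torch
from ImageReward.models.CLIPScore import CLIPScore

device = "cuda" if torch.cuda.is_available() else "cpu"
scorer = CLIPScore(download_root="./checkpoint/CLIP", device=device)

with torch.no_grad():
    # single image -> scalar CLIP similarity
    score = scorer.score("an image of Batman", "renders/batman_0.png")
    # several candidates -> 1-based ranking plus per-image scores
    ranking, scores = scorer.inference_rank(
        "an image of Batman",
        ["renders/batman_0.png", "renders/batman_1.png"],
    )

print(score, ranking, scores)
```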
ImageReward/models/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .AestheticScore import *
2
+ from .BLIPScore import *
3
+ from .CLIPScore import *
4
+ from .BLIP import *
ImageReward/utils.py ADDED
@@ -0,0 +1,184 @@
1
+ '''
2
+ @File : utils.py
3
+ @Time : 2023/04/05 19:18:00
4
+ @Author : Jiazheng Xu
5
+ @Contact : [email protected]
6
+ * Based on CLIP code base
7
+ * https://github.com/openai/CLIP
8
+ * Checkpoint of CLIP/BLIP/Aesthetic are from:
9
+ * https://github.com/openai/CLIP
10
+ * https://github.com/salesforce/BLIP
11
+ * https://github.com/christophschuhmann/improved-aesthetic-predictor
12
+ '''
13
+
14
+ import os
15
+ import urllib
16
+ from typing import Union, List
17
+ import pathlib
18
+
19
+ import torch
20
+ from tqdm import tqdm
21
+ from huggingface_hub import hf_hub_download
22
+
23
+ from .ImageReward import ImageReward
24
+ from .models.CLIPScore import CLIPScore
25
+ from .models.BLIPScore import BLIPScore
26
+ from .models.AestheticScore import AestheticScore
27
+
28
+ _MODELS = {
29
+ "ImageReward-v1.0": "https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt",
30
+ }
31
+
32
+
33
+ def available_models() -> List[str]:
34
+ """Returns the names of available ImageReward models"""
35
+ return list(_MODELS.keys())
36
+
37
+
38
+ def ImageReward_download(url: str, root: str):
39
+ os.makedirs(root, exist_ok=True)
40
+ filename = os.path.basename(url)
41
+ download_target = os.path.join(root, filename)
42
+ hf_hub_download(repo_id="THUDM/ImageReward", filename=filename, local_dir=root)
43
+ return download_target
44
+
45
+
46
+ def load(name: str = "ImageReward-v1.0",
47
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
48
+ download_root: str = None,
49
+ med_config_path: str = None):
50
+ """Load a ImageReward model
51
+
52
+ Parameters
53
+ ----------
54
+ name: str
55
+ A model name listed by `ImageReward.available_models()`, or the path to a model checkpoint containing the state_dict
56
+ device: Union[str, torch.device]
57
+ The device to put the loaded model
58
+ download_root: str
59
+ path to download the model files; by default, it uses "~/.cache/ImageReward"
60
+ med_config_path: str
61
+
62
+ Returns
63
+ -------
64
+ model : torch.nn.Module
65
+ The ImageReward model
66
+ """
67
+ if name in _MODELS:
68
+ download_root = download_root or "~/.cache/ImageReward"
69
+ download_root = pathlib.Path(download_root)
70
+ model_path = pathlib.Path(download_root) / 'ImageReward.pt'
71
+
72
+ if not model_path.exists():
73
+ model_path = ImageReward_download(_MODELS[name], root=download_root.as_posix())
74
+ elif os.path.isfile(name):
75
+ model_path = name
76
+ else:
77
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
78
+
79
+ print('-> load ImageReward model from %s' % model_path)
80
+ state_dict = torch.load(model_path, map_location='cpu')
81
+
82
+ # med_config
83
+ if med_config_path is None:
84
+ med_config_root = download_root or "~/.cache/ImageReward"
85
+ med_config_root = pathlib.Path(med_config_root)
86
+ med_config_path = med_config_root / 'med_config.json'
87
+
88
+ if not med_config_path.exists():
89
+ med_config_path = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
90
+ root=med_config_root.as_posix())
91
+ print('-> load ImageReward med_config from %s' % med_config_path)
92
+
93
+ model = ImageReward(device=device, med_config=med_config_path).to(device)
94
+ msg = model.load_state_dict(state_dict, strict=False)
95
+ model.eval()
96
+
97
+ return model
98
+
99
+
100
+ _SCORES = {
101
+ "CLIP": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
102
+ "BLIP": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth",
103
+ "Aesthetic": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth",
104
+ }
105
+
106
+
107
+ def available_scores() -> List[str]:
108
+ """Returns the names of available ImageReward scores"""
109
+ return list(_SCORES.keys())
110
+
111
+
112
+ def _download(url: str, root: str):
113
+ os.makedirs(root, exist_ok=True)
114
+ filename = os.path.basename(url)
115
+
116
+ download_target = os.path.join(root, filename)
117
+
118
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
119
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
120
+
121
+ if os.path.isfile(download_target):
122
+ return download_target
123
+
124
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
125
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
126
+ unit_divisor=1024) as loop:
127
+ while True:
128
+ buffer = source.read(8192)
129
+ if not buffer:
130
+ break
131
+
132
+ output.write(buffer)
133
+ loop.update(len(buffer))
134
+
135
+ return download_target
136
+
137
+
138
+ def load_score(name: str = "CLIP", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
139
+ download_root: str = None):
140
+ """Load a score model: CLIPScore, BLIPScore, or AestheticScore
141
+
142
+ Parameters
143
+ ----------
144
+ name : str
145
+ A model name listed by `ImageReward.available_scores()`
146
+
147
+ device : Union[str, torch.device]
148
+ The device to put the loaded model
149
+
150
+ download_root: str
151
+ path to download the model files; by default, it uses "~/.cache/ImageReward"
152
+
153
+ Returns
154
+ -------
155
+ model : torch.nn.Module
156
+ The score model (CLIPScore, BLIPScore, or AestheticScore)
157
+ """
158
+ model_download_root = download_root or os.path.expanduser("~/.cache/ImageReward")
159
+
160
+ if name in _SCORES:
161
+ model_path = _download(_SCORES[name], model_download_root)
162
+ else:
163
+ raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
164
+
165
+ print('load checkpoint from %s' % model_path)
166
+ if name == "BLIP":
167
+ state_dict = torch.load(model_path, map_location='cpu')
168
+ med_config = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json",
169
+ model_download_root)
170
+ model = BLIPScore(med_config=med_config, device=device).to(device)
171
+ model.blip.load_state_dict(state_dict['model'], strict=False)
172
+ elif name == "CLIP":
173
+ model = CLIPScore(download_root=model_download_root, device=device).to(device)
174
+ elif name == "Aesthetic":
175
+ state_dict = torch.load(model_path, map_location='cpu')
176
+ model = AestheticScore(download_root=model_download_root, device=device).to(device)
177
+ model.mlp.load_state_dict(state_dict, strict=False)
178
+ else:
179
+ raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}")
180
+
181
+ print("checkpoint loaded")
182
+ model.eval()
183
+
184
+ return model
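A minimal sketch of how the two loaders in `utils.py` might be called; the checkpoint directory and image path below are placeholders.

```python
import torch
from ImageReward.utils import load, load_score

device = "cuda" if torch.cuda.is_available() else "cpu"

# Full ImageReward reward model (ImageReward.pt and med_config.json are
# downloaded into download_root on the first call).
reward_model = load("ImageReward-v1.0", device=device,
                    download_root="./checkpoint/ImageReward")

# Auxiliary metric models: "CLIP", "BLIP" or "Aesthetic".
clip_scorer = load_score("CLIP", device=device,
                         download_root="./checkpoint/ImageReward")
print(clip_scorer.score("Sydney opera house. oil painting. by Van Gogh",
                        "renders/p_0.png"))
```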
README.md CHANGED
@@ -1,19 +1,91 @@
1
  # SVGDreamer: Text Guided SVG Generation with Diffusion Model
2
 
 
3
  [![arXiv](https://img.shields.io/badge/arXiv-2312.16476-b31b1b.svg)](https://arxiv.org/abs/2312.16476)
4
- [![website](https://img.shields.io/badge/website-Gitpage-yellow)](https://ximinng.github.io/SVGDreamer-project/)
 
 
5
 
6
- ### Code coming soon !!!
 
7
 
8
- Our project page can be found [here](https://ximinng.github.io/SVGDreamer-project/).
9
 
10
- ![title](./assets/teaser1.png)
11
- ![title](./assets/teaser2.png)
12
- ![title](./assets/teaser3.png)
13
 
14
- ### TODO
15
 
16
- - [ ] release the complete code
17
 
18
  ## :books: Acknowledgement
19
 
@@ -22,6 +94,8 @@ The project is built based on the following repository:
22
  - [BachiLi/diffvg](https://github.com/BachiLi/diffvg)
23
  - [huggingface/diffusers](https://github.com/huggingface/diffusers)
24
  - [ximinng/DiffSketcher](https://github.com/ximinng/DiffSketcher)
 
 
25
 
26
  We gratefully thank the authors for their wonderful works.
27
 
@@ -31,10 +105,10 @@ If you use this code for your research, please cite the following work:
31
 
32
  ```
33
  @article{xing2023svgdreamer,
34
- title={SVGDreamer: Text Guided SVG Generation with Diffusion Model},
35
- author={Xing, Ximing and Zhou, Haitao and Wang, Chuang and Zhang, Jing and Xu, Dong and Yu, Qian},
36
- journal={arXiv preprint arXiv:2312.16476},
37
- year={2023}
38
  }
39
  ```
40
 
 
1
  # SVGDreamer: Text Guided SVG Generation with Diffusion Model
2
 
3
+ [![cvpr24](https://img.shields.io/badge/CVPR-2024-387ADF.svg)](https://arxiv.org/abs/2312.16476)
4
  [![arXiv](https://img.shields.io/badge/arXiv-2312.16476-b31b1b.svg)](https://arxiv.org/abs/2312.16476)
5
+ [![website](https://img.shields.io/badge/Website-Gitpage-4CCD99)](https://ximinng.github.io/SVGDreamer-project/)
6
+ [![blog](https://img.shields.io/badge/Blog-ENG-9195F6)](https://huggingface.co/blog/xingxm/svgdreamer)
7
+ [![blog](https://img.shields.io/badge/Blog-CN-9195F6)](https://huggingface.co/blog/xingxm/svgdreamer)
8
 
9
+ This repository contains our official implementation of the CVPR 2024 paper: SVGDreamer: Text-Guided SVG Generation with
10
+ Diffusion Model. It generates high-quality vector graphics (SVGs) from text prompts.
11
 
12
+ [//]: # (> Project Page: https://ximinng.github.io/SVGDreamer-project/)
13
 
14
+ ![title](./assets/illustrate.png)
15
+ ![title](./assets/teaser_svg_asset.png)
 
16
 
17
+ ## :new: Update
18
 
19
+ - [03/2024] 🔥 We have released the **code** for [SVGDreamer](https://ximinng.github.io/SVGDreamer-project/).
20
+ - [02/2024] 🎉 **SVGDreamer accepted by CVPR2024.** 🎉
21
+ - [12/2023] 🔥 We have released the **[SVGDreamer Paper](https://arxiv.org/abs/2312.16476)**. SVGDreamer is
22
+ a novel text-guided vector graphics synthesis method that considers both the editability of vector graphics and
23
+ the quality of the synthesis.
24
+
25
+ ## 🔥Quickstart
26
+
27
+ Before running the code for the first time, download the Stable Diffusion checkpoint by appending `diffuser.download=True` to the command.
28
+
29
+ ### SIVE + VPSD
30
+
31
+ **Script:**
32
+
33
+ ```shell
34
+ python svgdreamer.py x=iconography skip_sive=False "prompt='an image of Batman. full body action pose, complete detailed body. white background. empty background, high quality, 4K, ultra realistic'" token_ind=4 x.vpsd.t_schedule='randint' result_path='./logs/batman' multirun=True mv=True
35
+ ```
36
+
37
+ - `x=iconography`(str): style configs
38
+ - `skip_sive`(bool): whether to skip the SIVE stage; set `skip_sive=False` to run SIVE before VPSD
39
+ - `token_ind`(int): the index of the keyword token in the text prompt (counting from 1)
40
+ - `result_path`(str): the path to save the result
41
+ - `multirun`(bool): run the script multiple times with different random seeds
42
+ - `mv`(bool): save intermediate results and record a video of the run (this increases the run time)
43
+
44
+ **More parameters are defined in `./conf/x/style.yaml`; you can override them from the command line. For
45
+ example, append `x.vpsd.n_particle=4` to the end of the command.**
46
+
47
+ ### VPSD
48
+
49
+ **Prompt:** Sydney opera house. oil painting. by Van Gogh <br/>
50
+ **Style:** iconography <br/>
51
+ **Preview:**
52
+
53
+ | Particle 1 | Particle 2 | Particle 3 | Particle 4 | Particle 5 | Particle 6 |
54
+ |--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|--------------------------------------------------------|
55
+ | init p1 | init p2 | init p3 | init p4 | init p5 | init p6 |
56
+ | <img src="./assets/Icon-SydneyOperaHouse/init_p0.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p1.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p2.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p3.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p4.svg"> | <img src="./assets/Icon-SydneyOperaHouse/init_p5.svg"> |
57
+ | final p1 | final p2 | final p3 | final p4 | final p5 | final p6 |
58
+ | <img src="./assets/Icon-SydneyOperaHouse/p_0.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_1.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_2.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_3.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_4.svg"> | <img src="assets/Icon-SydneyOperaHouse/p_5.svg"> |
59
+
60
+ **Script:**
61
+
62
+ ```shell
63
+ python svgdreamer.py x=iconography "prompt='Sydney opera house. oil painting. by Van Gogh'" result_path='./logs/SydneyOperaHouse-OilPainting'
64
+ ```
65
+
66
+ **Other Styles:**
67
+
68
+ ```shell
69
+ # Style: low-poly
70
+ python svgdreamer.py x=lowpoly "prompt='A picture of a bald eagle. low-ploy. polygon'" result_path='./logs/BaldEagle'
71
+ # Style: pixel-art
72
+ python svgdreamer.py x=pixelart "prompt='Darth vader with lightsaber.'" result_path='./log/DarthVader'
73
+ # Style: painting
74
+ python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" result_path='./logs/VanGogh-Portrait'
75
+ # Style: sketch
76
+ python svgdreamer.py x=sketch "prompt='A free-hand drawing of A speeding Lamborghini. black and white drawing.'" result_path='./logs/Lamborghini'
77
+ # Style: ink and wash
78
+ python svgdreamer.py x=ink "prompt='Big Wild Goose Pagoda. ink style. Minimalist abstract art grayscale watercolor.'" result_path='./logs/BigWildGoosePagoda'
79
+ ```
80
+
81
+ ## 🔑 Tips
82
+
83
+ - `x.vpsd.t_schedule` greatly affects the style of the result. Please experiment with different schedules.
84
+ - `neg_prompt`: negative prompts affect the quality of the results.
85
+
86
+ ## 📋 TODO
87
+
88
+ - [x] Release the code
89
 
90
  ## :books: Acknowledgement
91
 
 
94
  - [BachiLi/diffvg](https://github.com/BachiLi/diffvg)
95
  - [huggingface/diffusers](https://github.com/huggingface/diffusers)
96
  - [ximinng/DiffSketcher](https://github.com/ximinng/DiffSketcher)
97
+ - [THUDM/ImageReward](https://github.com/THUDM/ImageReward)
98
+ - [ximinng/PyTorch-SVGRender](https://github.com/ximinng/PyTorch-SVGRender)
99
 
100
  We gratefully thank the authors for their wonderful works.
101
 
 
105
 
106
  ```
107
  @article{xing2023svgdreamer,
108
+ title={SVGDreamer: Text Guided SVG Generation with Diffusion Model},
109
+ author={Xing, Ximing and Zhou, Haitao and Wang, Chuang and Zhang, Jing and Xu, Dong and Yu, Qian},
110
+ journal={arXiv preprint arXiv:2312.16476},
111
+ year={2023}
112
  }
113
  ```
114
 
assets/Icon-SydneyOperaHouse/init_p0.svg ADDED
assets/Icon-SydneyOperaHouse/init_p1.svg ADDED
assets/Icon-SydneyOperaHouse/init_p2.svg ADDED
assets/Icon-SydneyOperaHouse/init_p3.svg ADDED
assets/Icon-SydneyOperaHouse/init_p4.svg ADDED
assets/Icon-SydneyOperaHouse/init_p5.svg ADDED
assets/Icon-SydneyOperaHouse/p_0.svg ADDED
assets/Icon-SydneyOperaHouse/p_1.svg ADDED
assets/Icon-SydneyOperaHouse/p_2.svg ADDED
assets/Icon-SydneyOperaHouse/p_3.svg ADDED
assets/Icon-SydneyOperaHouse/p_4.svg ADDED
assets/Icon-SydneyOperaHouse/p_5.svg ADDED
assets/{teaser1.png → illustrate.png} RENAMED
File without changes
assets/{teaser2.png → teaser_cases.png} RENAMED
File without changes
assets/{teaser3.png → teaser_more_cases.png} RENAMED
File without changes
assets/teaser_svg_asset.png ADDED

Git LFS Details

  • SHA256: 69218254bb2274825271c1b283804a026e23990b2da1b19e64d8de6ea8774666
  • Pointer size: 132 Bytes
  • Size of remote file: 3.72 MB
conf/config.yaml ADDED
@@ -0,0 +1,54 @@
1
+ #-----------------#
2
+ # Global Config #
3
+ #-----------------#
4
+
5
+ # common args
6
+ prompt: ~
7
+ token_ind: 1 # the index of text prompt, from 1
8
+ neg_prompt: ~ # negative prompt
9
+ skip_sive: True # optimize from scratch without SIVE init
10
+
11
+ # Accelerate config
12
+ state:
13
+ cpu: False # use cpu
14
+ mprec: no # mixed precision, choices: 'no', 'fp16', 'bf16'
15
+
16
+ # Diffusers config
17
+ diffuser:
18
+ download: False # Set this variable to True the first time it runs
19
+ force_download: False
20
+ resume_download: False
21
+
22
+ # PyDiffVG config
23
+ diffvg:
24
+ print_timing: False
25
+
26
+ # reproduction
27
+ seed: 951222
28
+ # multi-run
29
+ multirun: False
30
+ srange: ~ # seed range, example: [100, 100]
31
+
32
+ # log
33
+ result_path: './workspace'
34
+ save_step: 50
35
+
36
+ # visual rendering process
37
+ mv: False # make video
38
+ framefreq: 5 # save the image interval
39
+ framerate: 24 # by adjusting the frame rate, you can control the playback speed of the output video
40
+
41
+ # hydra setting
42
+ hydra:
43
+ help:
44
+ # app name, override to match the name your app is known by
45
+ app_name: 'SVGDreamer'
46
+ run:
47
+ # output directory for normal runs
48
+ # warning: make sure that lines 53-55 of './libs/model_state.py' and 'dir' are modified together
49
+ dir: ./${result_path}/SVGDreamer-${now:%Y-%m-%d-%H-%M}
50
+
51
+ # default settings
52
+ defaults:
53
+ - _self_
54
+ - x: ~
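The keys above can all be overridden on the command line (as in the README), or composed programmatically with Hydra's compose API. A small sketch, with example override values only:

```python
from hydra import initialize, compose
from omegaconf import OmegaConf

# Compose the global config plus one style config from conf/x/, then
# override individual keys exactly as one would on the command line.
with initialize(version_base=None, config_path="conf"):
    cfg = compose(
        config_name="config",
        overrides=[
            "x=iconography",
            "prompt='Sydney opera house. oil painting. by Van Gogh'",
            "result_path=./logs/SydneyOperaHouse-OilPainting",
            "x.vpsd.n_particle=4",
        ],
    )

print(OmegaConf.to_yaml(cfg, resolve=False))
```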
conf/x/iconography.yaml ADDED
@@ -0,0 +1,188 @@
1
+ image_size: 600 # canvas size
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ color_init: 'rand' # if skip_live=True, then use color_init to init target_img
4
+ style: "iconography" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
5
+
6
+ # stable diffusion in SIVE stage
7
+ sive_model_cfg:
8
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
9
+ ldm_speed_up: False
10
+ enable_xformers: True
11
+ gradient_checkpoint: False
12
+ cpu_offload: True
13
+ num_inference_steps: 100
14
+ guidance_scale: 7.5 # sdxl default 5.0
15
+ lora_path: ~
16
+
17
+ # lr and optim
18
+ sive_stage_optim:
19
+ point: 1 # control points
20
+ width: 0.1 # stroke width
21
+ color: 0.01 # fill color and stroke color
22
+ bg: 0.01 # bg in render_warp
23
+ optim:
24
+ name: 'adam'
25
+ betas: [ 0.9, 0.9 ]
26
+ eps: 1e-6
27
+ schedule:
28
+ name: 'linear'
29
+ keep_ratio: 0.2
30
+ decay_ratio: 0.4
31
+
32
+ # SIVE rendering
33
+ sive:
34
+ attn_cfg: # init content via attn
35
+ cross_attn_res: 16
36
+ self_attn_res: 32
37
+ max_com: 20
38
+ mean_comp: False
39
+ comp_idx: 0
40
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
41
+ bg:
42
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
43
+ num_iter: 10
44
+ num_paths: 256
45
+ path_schedule: 'repeat' # 'repeat', 'list'
46
+ schedule_each: 128
47
+ width: 3 # sketch stroke width
48
+ num_segments: 4
49
+ segment_init: 'circle' # 'random'
50
+ radius: 20
51
+ coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
52
+ grid: 20
53
+ # optim
54
+ lr_schedule: True
55
+ optim_bg: False # train background
56
+ use_attn_init: True
57
+ softmax_tau: 0.3 # temperature of softmax
58
+ # loss
59
+ use_distance_weighted_loss: False
60
+ xing_loss_weight: 0.001
61
+ fg:
62
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
63
+ num_iter: 10
64
+ num_paths: 256 # number of strokes
65
+ path_schedule: 'repeat' # 'repeat', 'list'
66
+ schedule_each: 128
67
+ width: 3 # sketch stroke width
68
+ num_segments: 4
69
+ segment_init: 'circle' # 'random'
70
+ radius: 15
71
+ coord_init: 'random' # 'random', 'naive', place the first control point
72
+ grid: 20
73
+ # optim
74
+ lr_schedule: False
75
+ optim_bg: False # train background
76
+ use_attn_init: True
77
+ softmax_tau: 0.3 # temperature of softmax
78
+ # loss
79
+ use_distance_weighted_loss: False
80
+ xing_loss_weight: 0.01
81
+ tog: # for refinement
82
+ reinit: True # if False, use fg params to init content
83
+ num_iter: 1000
84
+ # optim
85
+ lr_schedule: False # enable lr_scheduler or not
86
+ # loss
87
+ bg_lam: 0
88
+ fg_lam: 1
89
+ xing_loss_weight: 0
90
+
91
+ # VPSD primitives
92
+ num_paths: 512 # number of strokes
93
+ trainable_bg: False # set the background to be trainable
94
+ width: 3 # stroke width
95
+ num_segments: 4
96
+ segment_init: 'circle' # 'random'
97
+ radius: 20
98
+ coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
99
+ grid: 50 # divide the canvas into n grids
100
+ path_reinit: # reinitializing paths
101
+ use: True
102
+ freq: 100 # reinitialize paths every 100 iterations
103
+ stop_step: 1000 # for VPSD fine-tuning
104
+ opacity_threshold: 0.05
105
+ area_threshold: 64
106
+
107
+ # lr and optim
108
+ vpsd_stage_optim:
109
+ point: 1
110
+ width: 0.1
111
+ color: 0.01
112
+ bg: 0.01
113
+ lr_schedule: True # use lr_scheduler
114
+ optim:
115
+ name: 'adam'
116
+ betas: [ 0.9, 0.9 ]
117
+ eps: 1e-6
118
+ schedule:
119
+ name: 'cosine'
120
+ warmup_steps: 10
121
+ warmup_start_lr: 0.02
122
+ warmup_end_lr: 0.9
123
+ cosine_end_lr: 0.4
124
+
125
+ # stable diffusion in VPSD stage
126
+ vpsd_model_cfg:
127
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
128
+ ldm_speed_up: False
129
+ enable_xformers: True
130
+ gradient_checkpoint: False
131
+ cpu_offload: True
132
+ num_inference_steps: 100
133
+ guidance_scale: 7.5 # sdxl default 5.0
134
+ lora_path: ~
135
+
136
+ # VPSD setting
137
+ vpsd:
138
+ type: 'vpsd'
139
+ n_particle: 6 # 4, 8, 16
140
+ vsd_n_particle: 4 # the batch size of particles
141
+ particle_aug: False # do data enhancement for the input particles
142
+ num_iter: 2000 # total iterations
143
+ guidance_scale: 7.5 # CFG value
144
+ grad_scale: 1.0 # increase or decrease the gradient
145
+ grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
146
+ t_range: [ 0.02, 0.98 ]
147
+ # 'randint': random timesteps; this may give a more authentic style.
148
+ # 'max_0.5_900': anneal the max timestep from 0.98 to 0.5 after 900 steps; this may give more colorful results.
149
+ t_schedule: 'max_0.5_1500' # or 'randint'
150
+ # phi model config
151
+ phi_single: False # if False new an unet model to estimate noise
152
+ phi_model: 'lora' # 'lora', 'unet_simple'
153
+ use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
154
+ lora_attn_scale: 1.0 # the scale of the attn based lora layer
155
+ phi_guidance_scale: 1.0
156
+ phi_t: False # different t for phi fine-tuning
157
+ phi_update_step: 1 # enable multi-update phi model or not
158
+ phi_lr: 0.0001 # learning rate of phi model
159
+ phi_scheduler: 'ddim'
160
+ phi_n_particle: 2 # the batch size of phi_model
161
+ # ReFL config
162
+ phi_ReFL: False # enable reward feedback learning (ReFL)
163
+ n_phi_sample: 1 # number of samples used in ReFL
164
+ phi_sample_step: 200 # the phi log step
165
+ phi_infer_step: 50 # the phi num_inference_steps
166
+ # phi model optim
167
+ phi_optim:
168
+ name: 'adamw'
169
+ betas: [ 0.9, 0.999 ]
170
+ eps: 1e-8
171
+ weight_decay: ~ # 1e-5
172
+ # phi model lr learning schedule
173
+ phi_schedule:
174
+ use: False
175
+ name: 'cosine'
176
+ warmup_steps: 50
177
+ warmup_start_lr: 0.00001
178
+ warmup_end_lr: 0.0001
179
+ total_step: 800
180
+ cosine_end_lr: 0.0001
181
+
182
+ # reward model
183
+ reward_path: './checkpoint/ImageReward'
184
+
185
+ # xing loss for closed-form paths
186
+ xing_loss:
187
+ use: False
188
+ weight: 0.01
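The `t_schedule` comments above describe two timestep-sampling modes. The sketch below is one plausible reading of the `'max_<t>_<step>'` naming convention, not necessarily the exact logic in the VPSD implementation elsewhere in this commit:

```python
import random

def sample_t(step: int, t_schedule: str,
             t_range=(0.02, 0.98), num_train_timesteps: int = 1000) -> int:
    """Pick a diffusion timestep for the current optimization step."""
    t_min = int(t_range[0] * num_train_timesteps)
    t_max = int(t_range[1] * num_train_timesteps)
    if t_schedule == 'randint':
        # uniform over the full range at every iteration
        return random.randint(t_min, t_max)
    if t_schedule.startswith('max_'):
        # e.g. 'max_0.5_1500': after 1500 steps, cap t at 0.5 * num_train_timesteps
        _, cap, anneal_step = t_schedule.split('_')
        if step >= int(anneal_step):
            t_max = min(t_max, int(float(cap) * num_train_timesteps))
        return random.randint(t_min, t_max)
    raise ValueError(f"unknown t_schedule: {t_schedule}")

# Late in training with 'max_0.5_1500', t is drawn from [20, 500].
print(sample_t(step=1600, t_schedule='max_0.5_1500'))
```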
conf/x/ink.yaml ADDED
@@ -0,0 +1,188 @@
1
+ image_size: 600 # canvas size
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ color_init: 'rand' # if skip_live=True, then use color_init to init target_img
4
+ style: "ink" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
5
+
6
+ # stable diffusion in SIVE stage
7
+ sive_model_cfg:
8
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
9
+ ldm_speed_up: False
10
+ enable_xformers: True
11
+ gradient_checkpoint: False
12
+ cpu_offload: True
13
+ num_inference_steps: 100
14
+ guidance_scale: 7.5 # sdxl default 5.0
15
+ lora_path: ~
16
+
17
+ # lr and optim
18
+ sive_stage_optim:
19
+ point: 1 # control points
20
+ width: 0.1 # stroke width
21
+ color: 0.01 # fill color and stroke color
22
+ bg: 0.01 # bg in render_warp
23
+ optim:
24
+ name: 'adam'
25
+ betas: [ 0.9, 0.9 ]
26
+ eps: 1e-6
27
+ schedule:
28
+ name: 'linear'
29
+ keep_ratio: 0.2
30
+ decay_ratio: 0.4
31
+
32
+ # SIVE rendering
33
+ sive:
34
+ attn_cfg: # init content via attn
35
+ cross_attn_res: 16
36
+ self_attn_res: 32
37
+ max_com: 20
38
+ mean_comp: False
39
+ comp_idx: 0
40
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
41
+ bg:
42
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
43
+ num_iter: 10
44
+ num_paths: 256
45
+ path_schedule: 'repeat' # 'repeat', 'list'
46
+ schedule_each: 128
47
+ width: 3 # sketch stroke width
48
+ num_segments: 4
49
+ segment_init: 'circle' # 'random'
50
+ radius: 20
51
+ coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
52
+ grid: 20
53
+ # optim
54
+ lr_schedule: True
55
+ optim_bg: False # train background
56
+ use_attn_init: True
57
+ softmax_tau: 0.3 # temperature of softmax
58
+ # loss
59
+ use_distance_weighted_loss: False
60
+ xing_loss_weight: 0.001
61
+ fg:
62
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
63
+ num_iter: 10
64
+ num_paths: 256 # number of strokes
65
+ path_schedule: 'repeat' # 'repeat', 'list'
66
+ schedule_each: 128
67
+ width: 3 # sketch stroke width
68
+ num_segments: 4
69
+ segment_init: 'circle' # 'random'
70
+ radius: 15
71
+ coord_init: 'random' # 'random', 'naive', place the first control point
72
+ grid: 20
73
+ # optim
74
+ lr_schedule: False
75
+ optim_bg: False # train background
76
+ use_attn_init: True
77
+ softmax_tau: 0.3 # temperature of softmax
78
+ # loss
79
+ use_distance_weighted_loss: False
80
+ xing_loss_weight: 0.01
81
+ tog: # for refinement
82
+ reinit: True # if False, use fg params to init content
83
+ num_iter: 1000
84
+ # optim
85
+ lr_schedule: False # enable lr_scheduler or not
86
+ # loss
87
+ bg_lam: 0
88
+ fg_lam: 1
89
+ xing_loss_weight: 0
90
+
91
+ # VPSD primitives
92
+ num_paths: 128 # number of strokes
93
+ trainable_bg: False # set the background to be trainable
94
+ width: 6 # stroke width
95
+ num_segments: 4
96
+ segment_init: 'circle' # 'random'
97
+ radius: 20
98
+ coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
99
+ grid: 50 # divide the canvas into n grids
100
+ path_reinit: # reinitializing paths
101
+ use: True
102
+ freq: 100 # reinitialize paths every 100 iterations
103
+ stop_step: 1000 # for VPSD fine-tuning
104
+ opacity_threshold: 0.05
105
+ area_threshold: 64
106
+
107
+ # lr and optim
108
+ vpsd_stage_optim:
109
+ point: 1
110
+ width: 0.1
111
+ color: 0.01
112
+ bg: 0.01
113
+ lr_schedule: True # use lr_scheduler
114
+ optim:
115
+ name: 'adam'
116
+ betas: [ 0.9, 0.9 ]
117
+ eps: 1e-6
118
+ schedule:
119
+ name: 'cosine'
120
+ warmup_steps: 10
121
+ warmup_start_lr: 0.02
122
+ warmup_end_lr: 0.9
123
+ cosine_end_lr: 0.4
124
+
125
+ # stable diffusion in VPSD stage
126
+ vpsd_model_cfg:
127
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
128
+ ldm_speed_up: False
129
+ enable_xformers: True
130
+ gradient_checkpoint: False
131
+ cpu_offload: True
132
+ num_inference_steps: 100
133
+ guidance_scale: 7.5 # sdxl default 5.0
134
+ lora_path: ~
135
+
136
+ # VPSD setting
137
+ vpsd:
138
+ type: 'vpsd'
139
+ n_particle: 6 # 4, 8, 16
140
+ vsd_n_particle: 4 # the batch size of particles
141
+ particle_aug: False # do data enhancement for the input particles
142
+ num_iter: 2000 # total iterations
143
+ guidance_scale: 7.5 # CFG value
144
+ grad_scale: 1.0 # increase or decrease the gradient
145
+ grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
146
+ t_range: [ 0.02, 0.98 ]
147
+ # 'randint': random timesteps; this may give a more authentic style.
148
+ # 'max_0.5_900': anneal the max timestep from 0.98 to 0.5 after 900 steps; this may give more colorful results.
149
+ t_schedule: 'randint' # or 'max_0.5_1500'
150
+ # phi model config
151
+ phi_single: False # if False new an unet model to estimate noise
152
+ phi_model: 'lora' # 'lora', 'unet_simple'
153
+ use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
154
+ lora_attn_scale: 1.0 # the scale of the attn based lora layer
155
+ phi_guidance_scale: 1.0
156
+ phi_t: False # different t for phi fine-tuning
157
+ phi_update_step: 1 # enable multi-update phi model or not
158
+ phi_lr: 0.0001 # learning rate of phi model
159
+ phi_scheduler: 'ddim'
160
+ phi_n_particle: 2 # the batch size of phi_model
161
+ # ReFL config
162
+ phi_ReFL: False # enable reward feedback learning (ReFL)
163
+ n_phi_sample: 1 # number of samples used in ReFL
164
+ phi_sample_step: 200 # the phi log step
165
+ phi_infer_step: 50 # the phi num_inference_steps
166
+ # phi model optim
167
+ phi_optim:
168
+ name: 'adamw'
169
+ betas: [ 0.9, 0.999 ]
170
+ eps: 1e-8
171
+ weight_decay: ~ # 1e-5
172
+ # phi model lr learning schedule
173
+ phi_schedule:
174
+ use: False
175
+ name: 'cosine'
176
+ warmup_steps: 50
177
+ warmup_start_lr: 0.00001
178
+ warmup_end_lr: 0.0001
179
+ total_step: 800
180
+ cosine_end_lr: 0.0001
181
+
182
+ # reward model
183
+ reward_path: './checkpoint/ImageReward'
184
+
185
+ # xing loss for closed-form paths
186
+ xing_loss:
187
+ use: False
188
+ weight: 0.01
conf/x/lowpoly.yaml ADDED
@@ -0,0 +1,188 @@
1
+ image_size: 600 # canvas size
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ color_init: 'rand' # if skip_live=True, then use color_init to init target_img
4
+ style: "low-poly" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
5
+
6
+ # stable diffusion in SIVE stage
7
+ sive_model_cfg:
8
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
9
+ ldm_speed_up: False
10
+ enable_xformers: True
11
+ gradient_checkpoint: False
12
+ cpu_offload: True
13
+ num_inference_steps: 100
14
+ guidance_scale: 7.5 # sdxl default 5.0
15
+ lora_path: ~
16
+
17
+ # lr and optim
18
+ sive_stage_optim:
19
+ point: 1 # control points
20
+ width: 0.1 # stroke width
21
+ color: 0.01 # fill color and stroke color
22
+ bg: 0.01 # bg in render_warp
23
+ optim:
24
+ name: 'adam'
25
+ betas: [ 0.9, 0.9 ]
26
+ eps: 1e-6
27
+ schedule:
28
+ name: 'linear'
29
+ keep_ratio: 0.2
30
+ decay_ratio: 0.4
31
+
32
+ # SIVE rendering
33
+ sive:
34
+ attn_cfg: # init content via attn
35
+ cross_attn_res: 16
36
+ self_attn_res: 32
37
+ max_com: 20
38
+ mean_comp: False
39
+ comp_idx: 0
40
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
41
+ bg:
42
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
43
+ num_iter: 10
44
+ num_paths: 256
45
+ path_schedule: 'repeat' # 'repeat', 'list'
46
+ schedule_each: 128
47
+ width: 3 # sketch stroke width
48
+ num_segments: 4
49
+ segment_init: 'circle' # 'random'
50
+ radius: 20
51
+ coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
52
+ grid: 20
53
+ # optim
54
+ lr_schedule: True
55
+ optim_bg: False # train background
56
+ use_attn_init: True
57
+ softmax_tau: 0.3 # temperature of softmax
58
+ # loss
59
+ use_distance_weighted_loss: False
60
+ xing_loss_weight: 0.001
61
+ fg:
62
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
63
+ num_iter: 10
64
+ num_paths: 256 # number of strokes
65
+ path_schedule: 'repeat' # 'repeat', 'list'
66
+ schedule_each: 128
67
+ width: 3 # sketch stroke width
68
+ num_segments: 4
69
+ segment_init: 'circle' # 'random'
70
+ radius: 15
71
+ coord_init: 'random' # 'random', 'naive', place the first control point
72
+ grid: 20
73
+ # optim
74
+ lr_schedule: False
75
+ optim_bg: False # train background
76
+ use_attn_init: True
77
+ softmax_tau: 0.3 # temperature of softmax
78
+ # loss
79
+ use_distance_weighted_loss: False
80
+ xing_loss_weight: 0.01
81
+ tog: # for refinement
82
+ reinit: True # if False, use fg params to init content
83
+ num_iter: 1000
84
+ # optim
85
+ lr_schedule: False # enable lr_scheduler or not
86
+ # loss
87
+ bg_lam: 0
88
+ fg_lam: 1
89
+ xing_loss_weight: 0
90
+
91
+ # VPSD primitives
92
+ num_paths: 512 # number of strokes
93
+ trainable_bg: False # set the background to be trainable
94
+ width: 3 # stroke width
95
+ num_segments: 4
96
+ segment_init: 'circle' # 'random'
97
+ radius: 20
98
+ coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
99
+ grid: 30 # divide the canvas into n grids
100
+ path_reinit: # reinitializing paths
101
+ use: True
102
+ freq: 100 # reinitialize paths every 100 iterations
103
+ stop_step: 1000 # for VPSD fine-tuning
104
+ opacity_threshold: 0.05
105
+ area_threshold: 64
106
+
107
+ # lr and optim
108
+ vpsd_stage_optim:
109
+ point: 1
110
+ width: 0.1
111
+ color: 0.01
112
+ bg: 0.01
113
+ lr_schedule: True # use lr_scheduler
114
+ optim:
115
+ name: 'adam'
116
+ betas: [ 0.9, 0.9 ]
117
+ eps: 1e-6
118
+ schedule:
119
+ name: 'cosine'
120
+ warmup_steps: 10
121
+ warmup_start_lr: 0.02
122
+ warmup_end_lr: 0.9
123
+ cosine_end_lr: 0.4
124
+
125
+ # stable diffusion in VPSD stage
126
+ vpsd_model_cfg:
127
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
128
+ ldm_speed_up: False
129
+ enable_xformers: True
130
+ gradient_checkpoint: False
131
+ cpu_offload: True
132
+ num_inference_steps: 100
133
+ guidance_scale: 7.5 # sdxl default 5.0
134
+ lora_path: ~
135
+
136
+ # VPSD setting
137
+ vpsd:
138
+ type: 'vpsd'
139
+ n_particle: 6 # 4, 8, 16
140
+ vsd_n_particle: 4 # the batch size of particles
141
+ particle_aug: False # do data enhancement for the input particles
142
+ num_iter: 1500 # total iterations
143
+ guidance_scale: 7.5 # CFG value
144
+ grad_scale: 1.0 # increase or decrease the gradient
145
+ grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
146
+ t_range: [ 0.02, 0.98 ]
147
+ # 'randint': random timesteps; this may give a more authentic style.
148
+ # 'max_0.5_900': anneal the max timestep from 0.98 to 0.5 after 900 steps; this may give more colorful results.
149
+ t_schedule: 'max_0.5_1500' # or 'randint'
150
+ # phi model config
151
+ phi_single: False # if False new an unet model to estimate noise
152
+ phi_model: 'lora' # 'lora', 'unet_simple'
153
+ use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
154
+ lora_attn_scale: 1.0 # the scale of the attn based lora layer
155
+ phi_guidance_scale: 1.0
156
+ phi_t: False # different t for phi fine-tuning
157
+ phi_update_step: 1 # enable multi-update phi model or not
158
+ phi_lr: 0.0001 # learning rate of phi model
159
+ phi_scheduler: 'ddim'
160
+ phi_n_particle: 2 # the batch size of phi_model
161
+ # ReFL config
162
+ phi_ReFL: False # enable reward feedback learning (ReFL)
163
+ n_phi_sample: 1 # number of samples used in ReFL
164
+ phi_sample_step: 200 # the phi log step
165
+ phi_infer_step: 50 # the phi num_inference_steps
166
+ # phi model optim
167
+ phi_optim:
168
+ name: 'adamw'
169
+ betas: [ 0.9, 0.999 ]
170
+ eps: 1e-8
171
+ weight_decay: ~ # 1e-5
172
+ # phi model lr learning schedule
173
+ phi_schedule:
174
+ use: False
175
+ name: 'cosine'
176
+ warmup_steps: 50
177
+ warmup_start_lr: 0.00001
178
+ warmup_end_lr: 0.0001
179
+ total_step: 800
180
+ cosine_end_lr: 0.0001
181
+
182
+ # reward model
183
+ reward_path: './checkpoint/ImageReward'
184
+
185
+ # xing loss for closed-form paths
186
+ xing_loss:
187
+ use: False
188
+ weight: 0.01
conf/x/painting.yaml ADDED
@@ -0,0 +1,188 @@
1
+ image_size: 600 # canvas size
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ color_init: 'rand' # if skip_live=True, then use color_init to init target_img
4
+ style: "painting" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
5
+
6
+ # stable diffusion in SIVE stage
7
+ sive_model_cfg:
8
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
9
+ ldm_speed_up: False
10
+ enable_xformers: True
11
+ gradient_checkpoint: False
12
+ cpu_offload: True
13
+ num_inference_steps: 100
14
+ guidance_scale: 7.5 # sdxl default 5.0
15
+ lora_path: ~
16
+
17
+ # lr and optim
18
+ sive_stage_optim:
19
+ point: 1 # control points
20
+ width: 0.1 # stroke width
21
+ color: 0.01 # fill color and stroke color
22
+ bg: 0.01 # bg in render_warp
23
+ optim:
24
+ name: 'adam'
25
+ betas: [ 0.9, 0.9 ]
26
+ eps: 1e-6
27
+ schedule:
28
+ name: 'linear'
29
+ keep_ratio: 0.2
30
+ decay_ratio: 0.4
31
+
32
+ # SIVE rendering
33
+ sive:
34
+ attn_cfg: # init content via attn
35
+ cross_attn_res: 16
36
+ self_attn_res: 32
37
+ max_com: 20
38
+ mean_comp: False
39
+ comp_idx: 0
40
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
41
+ bg:
42
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
43
+ num_iter: 10
44
+ num_paths: 256
45
+ path_schedule: 'repeat' # 'repeat', 'list'
46
+ schedule_each: 128
47
+ width: 3 # sketch stroke width
48
+ num_segments: 4
49
+ segment_init: 'circle' # 'random'
50
+ radius: 20
51
+ coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
52
+ grid: 20
53
+ # optim
54
+ lr_schedule: True
55
+ optim_bg: False # train background
56
+ use_attn_init: True
57
+ softmax_tau: 0.3 # temperature of softmax
58
+ # loss
59
+ use_distance_weighted_loss: False
60
+ xing_loss_weight: 0.001
61
+ fg:
62
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
63
+ num_iter: 10
64
+ num_paths: 256 # number of strokes
65
+ path_schedule: 'repeat' # 'repeat', 'list'
66
+ schedule_each: 128
67
+ width: 3 # sketch stroke width
68
+ num_segments: 4
69
+ segment_init: 'circle' # 'random'
70
+ radius: 15
71
+ coord_init: 'random' # 'random', 'naive', place the first control point
72
+ grid: 20
73
+ # optim
74
+ lr_schedule: False
75
+ optim_bg: False # train background
76
+ use_attn_init: True
77
+ softmax_tau: 0.3 # temperature of softmax
78
+ # loss
79
+ use_distance_weighted_loss: False
80
+ xing_loss_weight: 0.01
81
+ tog: # for refinement
82
+ reinit: True # if False, use fg params to init content
83
+ num_iter: 1000
84
+ # optim
85
+ lr_schedule: False # enable lr_scheduler or not
86
+ # loss
87
+ bg_lam: 0
88
+ fg_lam: 1
89
+ xing_loss_weight: 0
90
+
91
+ # VPSD primitives
92
+ num_paths: 1500 # number of strokes
93
+ trainable_bg: False # set the background to be trainable
94
+ width: 3 # stroke width
95
+ num_segments: 4
96
+ segment_init: 'circle' # 'random'
97
+ radius: 20
98
+ coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
99
+ grid: 50 # divide the canvas into n grids
100
+ path_reinit: # reinitializing paths
101
+ use: True
102
+ freq: 100 # reinitialize paths every 100 iterations
103
+ stop_step: 1000 # for VPSD fine-tuning
104
+ opacity_threshold: 0.05
105
+ area_threshold: 64
106
+
107
+ # lr and optim
108
+ vpsd_stage_optim:
109
+ point: 1
110
+ width: 0.1
111
+ color: 0.01
112
+ bg: 0.01
113
+ lr_schedule: True # use lr_scheduler
114
+ optim:
115
+ name: 'adam'
116
+ betas: [ 0.9, 0.9 ]
117
+ eps: 1e-6
118
+ schedule:
119
+ name: 'cosine'
120
+ warmup_steps: 10
121
+ warmup_start_lr: 0.02
122
+ warmup_end_lr: 0.9
123
+ cosine_end_lr: 0.4
124
+
125
+ # stable diffusion in VPSD stage
126
+ vpsd_model_cfg:
127
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
128
+ ldm_speed_up: False
129
+ enable_xformers: True
130
+ gradient_checkpoint: False
131
+ cpu_offload: True
132
+ num_inference_steps: 100
133
+ guidance_scale: 7.5 # sdxl default 5.0
134
+ lora_path: ~
135
+
136
+ # VPSD setting
137
+ vpsd:
138
+ type: 'vpsd'
139
+ n_particle: 6 # 4, 8, 16
140
+ vsd_n_particle: 4 # the batch size of particles
141
+ particle_aug: False # do data enhancement for the input particles
142
+ num_iter: 2000 # total iterations
143
+ guidance_scale: 7.5 # CFG value
144
+ grad_scale: 1.0 # increase or decrease the gradient
145
+ grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
146
+ t_range: [ 0.02, 0.98 ]
147
+ # 'randint': random timesteps; this may give a more authentic style.
148
+ # 'max_0.5_900': anneal the max timestep from 0.98 to 0.5 after 900 steps; this may give more colorful results.
149
+ t_schedule: 'randint' # or 'max_0.5_1500'
150
+ # phi model config
151
+ phi_single: False # if False new an unet model to estimate noise
152
+ phi_model: 'lora' # 'lora', 'unet_simple'
153
+ use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
154
+ lora_attn_scale: 1.0 # the scale of the attn based lora layer
155
+ phi_guidance_scale: 1.0
156
+ phi_t: False # different t for phi fine-tuning
157
+ phi_update_step: 1 # enable multi-update phi model or not
158
+ phi_lr: 0.0001 # learning rate of phi model
159
+ phi_scheduler: 'ddim'
160
+ phi_n_particle: 2 # the batch size of phi_model
161
+ # ReFL config
162
+ phi_ReFL: False # enable reward feedback learning (ReFL)
163
+ n_phi_sample: 1 # number of samples used in ReFL
164
+ phi_sample_step: 200 # the phi log step
165
+ phi_infer_step: 50 # the phi num_inference_steps
166
+ # phi model optim
167
+ phi_optim:
168
+ name: 'adamw'
169
+ betas: [ 0.9, 0.999 ]
170
+ eps: 1e-8
171
+ weight_decay: ~ # 1e-5
172
+ # phi model lr learning schedule
173
+ phi_schedule:
174
+ use: False
175
+ name: 'cosine'
176
+ warmup_steps: 50
177
+ warmup_start_lr: 0.00001
178
+ warmup_end_lr: 0.0001
179
+ total_step: 800
180
+ cosine_end_lr: 0.0001
181
+
182
+ # reward model
183
+ reward_path: './checkpoint/ImageReward'
184
+
185
+ # xing loss for closed-form paths
186
+ xing_loss:
187
+ use: False
188
+ weight: 0.01
conf/x/pixelart.yaml ADDED
@@ -0,0 +1,188 @@
1
+ image_size: 600 # canvas size
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ color_init: 'rand' # if skip_live=True, then use color_init to init target_img
4
+ style: "pixelart" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
5
+
6
+ # stable diffusion in SIVE stage
7
+ sive_model_cfg:
8
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
9
+ ldm_speed_up: False
10
+ enable_xformers: True
11
+ gradient_checkpoint: False
12
+ cpu_offload: True
13
+ num_inference_steps: 100
14
+ guidance_scale: 7.5 # sdxl default 5.0
15
+ lora_path: ~
16
+
17
+ # lr and optim
18
+ sive_stage_optim:
19
+ point: 1 # control points
20
+ width: 0.1 # stroke width
21
+ color: 0.01 # fill color and stroke color
22
+ bg: 0.01 # bg in render_warp
23
+ optim:
24
+ name: 'adam'
25
+ betas: [ 0.9, 0.9 ]
26
+ eps: 1e-6
27
+ schedule:
28
+ name: 'linear'
29
+ keep_ratio: 0.2
30
+ decay_ratio: 0.4
31
+
32
+ # SIVE rendering
33
+ sive:
34
+ attn_cfg: # init content via attn
35
+ cross_attn_res: 16
36
+ self_attn_res: 32
37
+ max_com: 20
38
+ mean_comp: False
39
+ comp_idx: 0
40
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
41
+ bg:
42
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
43
+ num_iter: 10
44
+ num_paths: 256
45
+ path_schedule: 'repeat' # 'repeat', 'list'
46
+ schedule_each: 128
47
+ width: 3 # sketch stroke width
48
+ num_segments: 4
49
+ segment_init: 'circle' # 'random'
50
+ radius: 20
51
+ coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
52
+ grid: 20
53
+ # optim
54
+ lr_schedule: True
55
+ optim_bg: False # train background
56
+ use_attn_init: True
57
+ softmax_tau: 0.3 # temperature of softmax
58
+ # loss
59
+ use_distance_weighted_loss: False
60
+ xing_loss_weight: 0.001
61
+ fg:
62
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
63
+ num_iter: 10
64
+ num_paths: 256 # number of strokes
65
+ path_schedule: 'repeat' # 'repeat', 'list'
66
+ schedule_each: 128
67
+ width: 3 # sketch stroke width
68
+ num_segments: 4
69
+ segment_init: 'circle' # 'random'
70
+ radius: 15
71
+ coord_init: 'random' # 'random', 'naive', place the first control point
72
+ grid: 20
73
+ # optim
74
+ lr_schedule: False
75
+ optim_bg: False # train background
76
+ use_attn_init: True
77
+ softmax_tau: 0.3 # temperature of softmax
78
+ # loss
79
+ use_distance_weighted_loss: False
80
+ xing_loss_weight: 0.01
81
+ tog: # for refinement
82
+ reinit: True # if False, use fg params to init content
83
+ num_iter: 1000
84
+ # optim
85
+ lr_schedule: False # enable lr_scheduler or not
86
+ # loss
87
+ bg_lam: 0
88
+ fg_lam: 1
89
+ xing_loss_weight: 0
90
+
91
+ # VPSD primitives
92
+ num_paths: 512 # number of strokes
93
+ trainable_bg: False # set the background to be trainable
94
+ width: 3 # stroke width
95
+ num_segments: 4
96
+ segment_init: 'circle' # 'random'
97
+ radius: 20
98
+ coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
99
+ grid: 50 # divide the canvas into n grids
100
+ path_reinit: # reinitializing paths
101
+ use: True
102
+ freq: 100 # reinitialize paths every 100 iterations
103
+ stop_step: 1000 # for VPSD fine-tuning
104
+ opacity_threshold: 0.05
105
+ area_threshold: 64
106
+
107
+ # lr and optim
108
+ vpsd_stage_optim:
109
+ point: 1
110
+ width: 0.1
111
+ color: 0.01
112
+ bg: 0.01
113
+ lr_schedule: True # use lr_scheduler
114
+ optim:
115
+ name: 'adam'
116
+ betas: [ 0.9, 0.9 ]
117
+ eps: 1e-6
118
+ schedule:
119
+ name: 'cosine'
120
+ warmup_steps: 10
121
+ warmup_start_lr: 0.02
122
+ warmup_end_lr: 0.9
123
+ cosine_end_lr: 0.4
124
+
125
+ # stable diffusion in VPSD stage
126
+ vpsd_model_cfg:
127
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
128
+ ldm_speed_up: False
129
+ enable_xformers: True
130
+ gradient_checkpoint: False
131
+ cpu_offload: True
132
+ num_inference_steps: 100
133
+ guidance_scale: 7.5 # sdxl default 5.0
134
+ lora_path: ~
135
+
136
+ # VPSD setting
137
+ vpsd:
138
+ type: 'vpsd'
139
+ n_particle: 6 # 4, 8, 16
140
+ vsd_n_particle: 4 # the batch size of particles
141
+ particle_aug: False # do data enhancement for the input particles
142
+ num_iter: 1000 # total iterations
143
+ guidance_scale: 7.5 # CFG value
144
+ grad_scale: 1.0 # increase or decrease the gradient
145
+ grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
146
+ t_range: [ 0.02, 0.98 ]
147
+ # 'randint': random timesteps; this may give a more authentic style.
148
+ # 'max_0.5_900': anneal the max timestep from 0.98 to 0.5 after 900 steps; this may give more colorful results.
149
+ t_schedule: 'max_0.5_1500' # or 'randint'
150
+ # phi model config
151
+ phi_single: False # if False new an unet model to estimate noise
152
+ phi_model: 'lora' # 'lora', 'unet_simple'
153
+ use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
154
+ lora_attn_scale: 1.0 # the scale of the attn based lora layer
155
+ phi_guidance_scale: 1.0
156
+ phi_t: False # different t for phi fine-tuning
157
+ phi_update_step: 1 # enable multi-update phi model or not
158
+ phi_lr: 0.0001 # learning rate of phi model
159
+ phi_scheduler: 'ddim'
160
+ phi_n_particle: 2 # the batch size of phi_model
161
+ # ReFL config
162
+ phi_ReFL: False # enable reward feedback learning (ReFL)
163
+ n_phi_sample: 1 # number of samples used in ReFL
164
+ phi_sample_step: 200 # the phi log step
165
+ phi_infer_step: 50 # the phi num_inference_steps
166
+ # phi model optim
167
+ phi_optim:
168
+ name: 'adamw'
169
+ betas: [ 0.9, 0.999 ]
170
+ eps: 1e-8
171
+ weight_decay: ~ # 1e-5
172
+ # phi model lr learning schedule
173
+ phi_schedule:
174
+ use: False
175
+ name: 'cosine'
176
+ warmup_steps: 50
177
+ warmup_start_lr: 0.00001
178
+ warmup_end_lr: 0.0001
179
+ total_step: 800
180
+ cosine_end_lr: 0.0001
181
+
182
+ # reward model
183
+ reward_path: './checkpoint/ImageReward'
184
+
185
+ # xing loss for closed-form paths
186
+ xing_loss:
187
+ use: False
188
+ weight: 0.01
conf/x/sketch.yaml ADDED
@@ -0,0 +1,188 @@
1
+ image_size: 600 # canvas size
2
+ path_svg: ~ # if you want to load a svg file and train from it
3
+ color_init: 'rand' # if skip_live=True, then use color_init to init target_img
4
+ style: "sketch" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink"
5
+
6
+ # stable diffusion in SIVE stage
7
+ sive_model_cfg:
8
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
9
+ ldm_speed_up: False
10
+ enable_xformers: True
11
+ gradient_checkpoint: False
12
+ cpu_offload: True
13
+ num_inference_steps: 100
14
+ guidance_scale: 7.5 # sdxl default 5.0
15
+ lora_path: ~
16
+
17
+ # lr and optim
18
+ sive_stage_optim:
19
+ point: 1 # control points
20
+ width: 0.1 # stroke width
21
+ color: 0.01 # fill color and stroke color
22
+ bg: 0.01 # bg in render_warp
23
+ optim:
24
+ name: 'adam'
25
+ betas: [ 0.9, 0.9 ]
26
+ eps: 1e-6
27
+ schedule:
28
+ name: 'linear'
29
+ keep_ratio: 0.2
30
+ decay_ratio: 0.4
31
+
32
+ # SIVE rendering
33
+ sive:
34
+ attn_cfg: # init content via attn
35
+ cross_attn_res: 16
36
+ self_attn_res: 32
37
+ max_com: 20
38
+ mean_comp: False
39
+ comp_idx: 0
40
+ attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn
41
+ bg:
42
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
43
+ num_iter: 10
44
+ num_paths: 256
45
+ path_schedule: 'repeat' # 'repeat', 'list'
46
+ schedule_each: 128
47
+ width: 3 # sketch stroke width
48
+ num_segments: 4
49
+ segment_init: 'circle' # 'random'
50
+ radius: 20
51
+ coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point
52
+ grid: 20
53
+ # optim
54
+ lr_schedule: True
55
+ optim_bg: False # train background
56
+ use_attn_init: True
57
+ softmax_tau: 0.3 # temperature of softmax
58
+ # loss
59
+ use_distance_weighted_loss: False
60
+ xing_loss_weight: 0.001
61
+ fg:
62
+ style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink'
63
+ num_iter: 10
64
+ num_paths: 256 # number of strokes
65
+ path_schedule: 'repeat' # 'repeat', 'list'
66
+ schedule_each: 128
67
+ width: 3 # sketch stroke width
68
+ num_segments: 4
69
+ segment_init: 'circle' # 'random'
70
+ radius: 15
71
+ coord_init: 'random' # 'random', 'naive', place the first control point
72
+ grid: 20
73
+ # optim
74
+ lr_schedule: False
75
+ optim_bg: False # train background
76
+ use_attn_init: True
77
+ softmax_tau: 0.3 # temperature of softmax
78
+ # loss
79
+ use_distance_weighted_loss: False
80
+ xing_loss_weight: 0.01
81
+ tog: # for refinement
82
+ reinit: True # if False, use fg params to init content
83
+ num_iter: 1000
84
+ # optim
85
+ lr_schedule: False # enable lr_scheduler or not
86
+ # loss
87
+ bg_lam: 0
88
+ fg_lam: 1
89
+ xing_loss_weight: 0
90
+
91
+ # VPSD primitives
92
+ num_paths: 128 # number of strokes
93
+ trainable_bg: False # set the background to be trainable
94
+ width: 3 # stroke width
95
+ num_segments: 4
96
+ segment_init: 'circle' # 'random'
97
+ radius: 20
98
+ coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point
99
+ grid: 50 # divide the canvas into n grids
100
+ path_reinit: # reinitializing paths
101
+ use: True
102
+ freq: 100 # reinitialize paths every 100 iterations
103
+ stop_step: 1000 # for VPSD fine-tuning
104
+ opacity_threshold: 0.05
105
+ area_threshold: 64
106
+
107
+ # lr and optim
108
+ vpsd_stage_optim:
109
+ point: 1
110
+ width: 0.1
111
+ color: 0.01
112
+ bg: 0.01
113
+ lr_schedule: True # use lr_scheduler
114
+ optim:
115
+ name: 'adam'
116
+ betas: [ 0.9, 0.9 ]
117
+ eps: 1e-6
118
+ schedule:
119
+ name: 'cosine'
120
+ warmup_steps: 10
121
+ warmup_start_lr: 0.02
122
+ warmup_end_lr: 0.9
123
+ cosine_end_lr: 0.4
124
+
125
+ # stable diffusion in VPSD stage
126
+ vpsd_model_cfg:
127
+ model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl
128
+ ldm_speed_up: False
129
+ enable_xformers: True
130
+ gradient_checkpoint: False
131
+ cpu_offload: True
132
+ num_inference_steps: 100
133
+ guidance_scale: 7.5 # sdxl default 5.0
134
+ lora_path: ~
135
+
136
+ # VPSD setting
137
+ vpsd:
138
+ type: 'vpsd'
139
+ n_particle: 6 # 4, 8, 16
140
+ vsd_n_particle: 4 # the batch size of particles
141
+ particle_aug: False # do data enhancement for the input particles
142
+ num_iter: 2000 # total iterations
143
+ guidance_scale: 7.5 # CFG value
144
+ grad_scale: 1.0 # increase or decrease the gradient
145
+ grad_clip_val: ~ # eg: 10, clip the gradient of VPSD
146
+ t_range: [ 0.02, 0.98 ]
147
+ # 'randint': random timesteps; this may give a more authentic style.
148
+ # 'max_0.5_900': anneal the max timestep from 0.98 to 0.5 after 900 steps; this may give more colorful results.
149
+ t_schedule: 'randint' # or 'max_0.5_1500'
150
+ # phi model config
151
+ phi_single: False # if False new an unet model to estimate noise
152
+ phi_model: 'lora' # 'lora', 'unet_simple'
153
+ use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not
154
+ lora_attn_scale: 1.0 # the scale of the attn based lora layer
155
+ phi_guidance_scale: 1.0
156
+ phi_t: False # different t for phi fine-tuning
157
+ phi_update_step: 1 # enable multi-update phi model or not
158
+ phi_lr: 0.0001 # learning rate of phi model
159
+ phi_scheduler: 'ddim'
160
+ phi_n_particle: 2 # the batch size of phi_model
161
+ # ReFL config
162
+ phi_ReFL: False # enable reward feedback learning (ReFL)
163
+ n_phi_sample: 1 # number of samples used in ReFL
164
+ phi_sample_step: 200 # the phi log step
165
+ phi_infer_step: 50 # the phi num_inference_steps
166
+ # phi model optim
167
+ phi_optim:
168
+ name: 'adamw'
169
+ betas: [ 0.9, 0.999 ]
170
+ eps: 1e-8
171
+ weight_decay: ~ # 1e-5
172
+ # phi model lr learning schedule
173
+ phi_schedule:
174
+ use: False
175
+ name: 'cosine'
176
+ warmup_steps: 50
177
+ warmup_start_lr: 0.00001
178
+ warmup_end_lr: 0.0001
179
+ total_step: 800
180
+ cosine_end_lr: 0.0001
181
+
182
+ # reward model
183
+ reward_path: './checkpoint/ImageReward'
184
+
185
+ # xing loss for closed-form paths
186
+ xing_loss:
187
+ use: False
188
+ weight: 0.01
svgdreamer.py ADDED
@@ -0,0 +1,42 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Author: ximing xing
3
+ # Description: the main func of this project.
4
+ # Copyright (c) 2023, XiMing Xing.
5
+
6
+ import os
7
+ import sys
8
+ from functools import partial
9
+
10
+ from accelerate.utils import set_seed
11
+ import hydra
12
+ import omegaconf
13
+
14
+ sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
15
+
16
+ from svgdreamer.utils import render_batch_wrap, get_seed_range
17
+ from svgdreamer.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline
18
+
19
+
20
+ @hydra.main(version_base=None, config_path="conf", config_name='config')
21
+ def main(cfg: omegaconf.DictConfig):
22
+ """
23
+ The project configuration is stored in './conf/config.yaml',
24
+ and the style configurations are stored in './conf/x/' (e.g. './conf/x/iconography.yaml').
25
+ """
26
+
27
+ # set seed
28
+ set_seed(cfg.seed)
29
+ seed_range = get_seed_range(cfg.srange) if cfg.multirun else None
30
+
31
+ # render function
32
+ render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range)
33
+
34
+ if not cfg.multirun: # generate a single SVG
35
+ pipe = SVGDreamerPipeline(cfg)
36
+ pipe.painterly_rendering(cfg.prompt)
37
+ else: # generate multiple SVGs over a range of seeds
38
+ render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None)
39
+
40
+
41
+ if __name__ == '__main__':
42
+ main()
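A hedged usage sketch: the same `cfg` that `@hydra.main` builds can also be composed programmatically with Hydra's compose API (the override names below assume the defaults declared in conf/config.yaml):

    # sketch only; 'x=iconography' assumes the style group defined by conf/x/*.yaml
    from hydra import initialize, compose
    from svgdreamer.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline

    with initialize(version_base=None, config_path="conf"):
        cfg = compose(config_name="config", overrides=["x=iconography", "seed=8019"])
    pipe = SVGDreamerPipeline(cfg)
    pipe.painterly_rendering(cfg.prompt)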
svgdreamer/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Author: ximing
3
+ # Copyright (c) 2023, XiMing Xing.
4
+ # License: MIT
5
+
6
+ __version__ = "1.0"
svgdreamer/diffusers_warp/__init__.py ADDED
@@ -0,0 +1,248 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+ from typing import AnyStr
6
+ import pathlib
7
+ from collections import OrderedDict
8
+ from packaging import version
9
+
10
+ import torch
11
+ from diffusers import StableDiffusionPipeline, SchedulerMixin
12
+ from diffusers import UNet2DConditionModel
13
+ from diffusers.utils import is_torch_version, is_xformers_available
14
+
15
+ DiffusersModels = OrderedDict({
16
+ "sd14": "CompVis/stable-diffusion-v1-4", # resolution: 512
17
+ "sd15": "runwayml/stable-diffusion-v1-5", # resolution: 512
18
+ "sd21b": "stabilityai/stable-diffusion-2-1-base", # resolution: 512
19
+ "sd21": "stabilityai/stable-diffusion-2-1", # resolution: 768
20
+ "sdxl": "stabilityai/stable-diffusion-xl-base-1.0", # resolution: 1024
21
+ })
22
+
23
+ # default resolution
24
+ _model2resolution = {
25
+ "sd14": 512,
26
+ "sd15": 512,
27
+ "sd21b": 512,
28
+ "sd21": 768,
29
+ "sdxl": 1024,
30
+ }
31
+
32
+
33
+ def model2res(model_id: str):
34
+ return _model2resolution.get(model_id, 512)
35
+
36
+
37
+ def init_StableDiffusion_pipeline(model_id: AnyStr,
38
+ custom_pipeline: StableDiffusionPipeline,
39
+ custom_scheduler: SchedulerMixin = None,
40
+ device: torch.device = "cuda",
41
+ torch_dtype: torch.dtype = torch.float32,
42
+ local_files_only: bool = True,
43
+ force_download: bool = False,
44
+ resume_download: bool = False,
45
+ ldm_speed_up: bool = False,
46
+ enable_xformers: bool = True,
47
+ gradient_checkpoint: bool = False,
48
+ cpu_offload: bool = False,
49
+ vae_slicing: bool = False,
50
+ lora_path: AnyStr = None,
51
+ unet_path: AnyStr = None) -> StableDiffusionPipeline:
52
+ """
53
+ A helper for initializing a diffusers pipeline.
54
+
55
+ Args:
56
+ model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path
57
+ custom_pipeline: any StableDiffusionPipeline pipeline
58
+ custom_scheduler: any scheduler
59
+ device: set device
60
+ torch_dtype: data type
61
+ local_files_only: do not download; use local files only
62
+ force_download: force re-downloading the model
63
+ resume_download: resume an interrupted download
64
+ ldm_speed_up: use the `torch.compile` api to speed up unet
65
+ enable_xformers: enable memory efficient attention from [xFormers]
66
+ gradient_checkpoint: activates gradient checkpointing for the current model
67
+ cpu_offload: enable sequential cpu offload
68
+ vae_slicing: enable sliced VAE decoding
69
+ lora_path: load LoRA checkpoint
70
+ unet_path: load unet checkpoint
71
+
72
+ Returns:
73
+ diffusers.StableDiffusionPipeline
74
+ """
75
+
76
+ # get model id
77
+ model_id = DiffusersModels.get(model_id, model_id)
78
+
79
+ # process diffusion model
80
+ if custom_scheduler is not None:
81
+ pipeline = custom_pipeline.from_pretrained(
82
+ model_id,
83
+ torch_dtype=torch_dtype,
84
+ local_files_only=local_files_only,
85
+ force_download=force_download,
86
+ resume_download=resume_download,
87
+ scheduler=custom_scheduler.from_pretrained(model_id,
88
+ subfolder="scheduler",
89
+ local_files_only=local_files_only,
90
+ force_download=force_download,
91
+ resume_download=resume_download)
92
+ ).to(device)
93
+ else:
94
+ pipeline = custom_pipeline.from_pretrained(
95
+ model_id,
96
+ torch_dtype=torch_dtype,
97
+ local_files_only=local_files_only,
98
+ force_download=force_download,
99
+ resume_download=resume_download,
100
+ ).to(device)
101
+
102
+ print(f"load diffusers pipeline: {model_id}")
103
+
104
+ # process unet model if exist
105
+ if unet_path is not None and pathlib.Path(unet_path).exists():
106
+ print(f"=> load u-net from {unet_path}")
107
+ pipeline.unet.from_pretrained(model_id, subfolder="unet")
108
+
109
+ # process lora layers if exist
110
+ if lora_path is not None and pathlib.Path(lora_path).exists():
111
+ pipeline.unet.load_attn_procs(lora_path)
112
+ print(f"=> load lora layers into U-Net from {lora_path} ...")
113
+
114
+ # torch.compile
115
+ if ldm_speed_up:
116
+ if is_torch_version(">=", "2.0.0"):
117
+ pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
118
+ print(f"=> enable torch.compile on U-Net")
119
+ else:
120
+ print(f"=> warning: calling torch.compile speed-up failed, since torch version <= 2.0.0")
121
+
122
+ # Meta xformers
123
+ if enable_xformers:
124
+ if is_xformers_available():
125
+ import xformers
126
+
127
+ xformers_version = version.parse(xformers.__version__)
128
+ if xformers_version == version.parse("0.0.16"):
129
+ print(
130
+ "xFormers 0.0.16 cannot be used for training in some GPUs. "
131
+ "If you observe problems during training, please update xFormers to at least 0.0.17. "
132
+ "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
133
+ )
134
+ print(f"=> enable xformers")
135
+ pipeline.unet.enable_xformers_memory_efficient_attention()
136
+ else:
137
+ print(f"=> warning: xformers is not available.")
138
+
139
+ # gradient checkpointing
140
+ if gradient_checkpoint:
141
+ # if pipeline.unet.is_gradient_checkpointing:
142
+ if True:
143
+ print(f"=> enable gradient checkpointing")
144
+ pipeline.unet.enable_gradient_checkpointing()
145
+ else:
146
+ print("=> waring: gradient checkpointing is not activated for this model.")
147
+
148
+ if cpu_offload:
149
+ pipeline.enable_sequential_cpu_offload()
150
+
151
+ if vae_slicing:
152
+ pipeline.enable_vae_slicing()
153
+
154
+ print(pipeline.scheduler)
155
+ return pipeline
156
+
157
+
158
+ def init_diffusers_unet(model_id: AnyStr,
159
+ device: torch.device = "cuda",
160
+ torch_dtype: torch.dtype = torch.float32,
161
+ local_files_only: bool = True,
162
+ force_download: bool = False,
163
+ resume_download: bool = False,
164
+ ldm_speed_up: bool = False,
165
+ enable_xformers: bool = True,
166
+ gradient_checkpoint: bool = False,
167
+ lora_path: AnyStr = None,
168
+ unet_path: AnyStr = None):
169
+ """
170
+ A helper for initializing a diffusers UNet model.
171
+
172
+ Args:
173
+ model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path
174
+ device: set device
175
+ torch_dtype: data type
176
+ local_files_only: do not download; use local files only
177
+ force_download: force re-downloading the model
178
+ resume_download: resume an interrupted download
179
+ ldm_speed_up: use the `torch.compile` api to speed up unet
180
+ enable_xformers: enable memory efficient attention from [xFormers]
181
+ gradient_checkpoint: activates gradient checkpointing for the current model
182
+ lora_path: load LoRA checkpoint
183
+ unet_path: load unet checkpoint
184
+
185
+ Returns:
186
+ diffusers.UNet
187
+ """
188
+
189
+ # get model id
190
+ model_id = DiffusersModels.get(model_id, model_id)
191
+
192
+ # process UNet model
193
+ unet = UNet2DConditionModel.from_pretrained(
194
+ model_id,
195
+ subfolder="unet",
196
+ torch_dtype=torch_dtype,
197
+ local_files_only=local_files_only,
198
+ force_download=force_download,
199
+ resume_download=resume_download,
200
+ ).to(device)
201
+
202
+ print(f"load diffusers UNet: {model_id}")
203
+
204
+ # process unet model if exist
205
+ if unet_path is not None and pathlib.Path(unet_path).exists():
206
+ print(f"=> load u-net from {unet_path}")
207
+ unet.from_pretrained(model_id)
208
+
209
+ # process lora layers if exist
210
+ if lora_path is not None and pathlib.Path(lora_path).exists():
211
+ unet.load_attn_procs(lora_path)
212
+ print(f"=> load lora layers into U-Net from {lora_path} ...")
213
+
214
+ # torch.compile
215
+ if ldm_speed_up:
216
+ if is_torch_version(">=", "2.0.0"):
217
+ unet = torch.compile(unet, mode="reduce-overhead", fullgraph=True)
218
+ print(f"=> enable torch.compile on U-Net")
219
+ else:
220
+ print(f"=> warning: calling torch.compile speed-up failed, since torch version <= 2.0.0")
221
+
222
+ # Meta xformers
223
+ if enable_xformers:
224
+ if is_xformers_available():
225
+ import xformers
226
+
227
+ xformers_version = version.parse(xformers.__version__)
228
+ if xformers_version == version.parse("0.0.16"):
229
+ print(
230
+ "xFormers 0.0.16 cannot be used for training in some GPUs. "
231
+ "If you observe problems during training, please update xFormers to at least 0.0.17. "
232
+ "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
233
+ )
234
+ print(f"=> enable xformers")
235
+ unet.enable_xformers_memory_efficient_attention()
236
+ else:
237
+ print(f"=> warning: xformers is not available.")
238
+
239
+ # gradient checkpointing
240
+ if gradient_checkpoint:
241
+ # if unet.is_gradient_checkpointing:
242
+ if True:
243
+ print(f"=> enable gradient checkpointing")
244
+ unet.enable_gradient_checkpointing()
245
+ else:
246
+ print("=> waring: gradient checkpointing is not activated for this model.")
247
+
248
+ return unet
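A hedged usage sketch of the two helpers above (assumes the SD 2.1-base weights are available; set `local_files_only=False` to allow downloading):

    # minimal sketch; not part of the commit
    import torch
    from diffusers import StableDiffusionPipeline, DDIMScheduler
    from svgdreamer.diffusers_warp import init_StableDiffusion_pipeline, model2res

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = init_StableDiffusion_pipeline(
        "sd21b",
        custom_pipeline=StableDiffusionPipeline,
        custom_scheduler=DDIMScheduler,
        device=device,
        local_files_only=False,   # allow downloading the weights
        enable_xformers=False,    # keep the example free of optional dependencies
    )
    res = model2res("sd21b")      # 512 for sd21b
    image = pipe("an icon of the Sydney Opera House", height=res, width=res).images[0]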
svgdreamer/diffvg_warp/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from .diffvg_state import DiffVGState, init_pydiffvg
7
+
8
+ __all__ = [
9
+ 'DiffVGState',
10
+ 'init_pydiffvg'
11
+ ]
svgdreamer/diffvg_warp/diffvg_state.py ADDED
@@ -0,0 +1,299 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Author: ximing
3
+ # Description: parent class
4
+ # Copyright (c) 2023, XiMing Xing.
5
+ # License: MIT License
6
+ import pathlib
7
+ from typing import AnyStr, List, Union
8
+ import xml.etree.ElementTree as etree
9
+
10
+ import torch
11
+ import pydiffvg
12
+
13
+
14
+ def init_pydiffvg(device: torch.device,
15
+ use_gpu: bool = torch.cuda.is_available(),
16
+ print_timing: bool = False):
17
+ pydiffvg.set_use_gpu(use_gpu)
18
+ pydiffvg.set_device(device)
19
+ pydiffvg.set_print_timing(print_timing)
20
+
21
+
22
+ class DiffVGState(torch.nn.Module):
23
+
24
+ def __init__(self,
25
+ device: torch.device,
26
+ use_gpu: bool = torch.cuda.is_available(),
27
+ print_timing: bool = False,
28
+ canvas_width: int = None,
29
+ canvas_height: int = None):
30
+ super(DiffVGState, self).__init__()
31
+ # pydiffvg device setting
32
+ self.device = device
33
+ init_pydiffvg(device, use_gpu, print_timing)
34
+
35
+ # canvas size
36
+ self.canvas_width = canvas_width
37
+ self.canvas_height = canvas_height
38
+
39
+ # record all paths
40
+ self.shapes = []
41
+ self.shape_groups = []
42
+ # record the current optimized path
43
+ self.cur_shapes = []
44
+ self.cur_shape_groups = []
45
+
46
+ # learnable SVG params
47
+ self.point_vars = []
48
+ self.color_vars = []
49
+ self.width_vars = []
50
+
51
+ def clip_curve_shape(self, *args, **kwargs):
52
+ raise NotImplementedError
53
+
54
+ def render_warp(self, seed=0):
55
+ self.clip_curve_shape()
56
+
57
+ scene_args = pydiffvg.RenderFunction.serialize_scene(
58
+ self.canvas_width, self.canvas_height, self.shapes, self.shape_groups
59
+ )
60
+ _render = pydiffvg.RenderFunction.apply
61
+ img = _render(self.canvas_width, # width
62
+ self.canvas_height, # height
63
+ 2, # num_samples_x
64
+ 2, # num_samples_y
65
+ seed, # seed
66
+ None,
67
+ *scene_args)
68
+ return img
69
+
70
+ def render_image(self, canvas_width, canvas_height, shapes, shape_groups, seed: int = 0):
71
+ scene_args = pydiffvg.RenderFunction.serialize_scene(
72
+ canvas_width, canvas_height, shapes, shape_groups
73
+ )
74
+ _render = pydiffvg.RenderFunction.apply
75
+ img = _render(canvas_width, # width
76
+ canvas_height, # height
77
+ 2, # num_samples_x
78
+ 2, # num_samples_y
79
+ seed, # seed
80
+ None,
81
+ *scene_args)
82
+ img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4])
83
+ img = img.unsqueeze(0) # convert img from HWC to NCHW
84
+ img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW
85
+ return img
86
+
87
+ @staticmethod
88
+ def load_svg(path_svg):
89
+ canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
90
+ return canvas_width, canvas_height, shapes, shape_groups
91
+
92
+ def save_svg(self,
93
+ filename: Union[AnyStr, pathlib.Path],
94
+ width: int = None,
95
+ height: int = None,
96
+ shapes: List = None,
97
+ shape_groups: List = None,
98
+ use_gamma: bool = False,
99
+ background: str = None):
100
+ """
101
+ Save an SVG file with specified parameters and shapes.
102
+ Note: a new version of the SVG saving function, adapted from pydiffvg.save_svg.
103
+ The original version saved text in a way that produced incomplete glyphs.
104
+
105
+ Args:
106
+ filename (str): The path to save the SVG file.
107
+ width (int): The width of the SVG canvas.
108
+ height (int): The height of the SVG canvas.
109
+ shapes (list): A list of shapes to be included in the SVG.
110
+ shape_groups (list): A list of shape groups.
111
+ use_gamma (bool): Flag indicating whether to apply gamma correction.
112
+ background (str, optional): The background color of the SVG.
113
+
114
+ Returns:
115
+ None
116
+ """
117
+ root = etree.Element('svg')
118
+ root.set('version', '1.1')
119
+ root.set('xmlns', 'http://www.w3.org/2000/svg')
120
+ root.set('width', str(width))
121
+ root.set('height', str(height))
122
+
123
+ if background is not None:
124
+ print(f"setting background to {background}")
125
+ root.set('style', str(background))
126
+
127
+ defs = etree.SubElement(root, 'defs')
128
+ g = etree.SubElement(root, 'g')
129
+
130
+ if use_gamma:
131
+ f = etree.SubElement(defs, 'filter')
132
+ f.set('id', 'gamma')
133
+ f.set('x', '0')
134
+ f.set('y', '0')
135
+ f.set('width', '100%')
136
+ f.set('height', '100%')
137
+ gamma = etree.SubElement(f, 'feComponentTransfer')
138
+ gamma.set('color-interpolation-filters', 'sRGB')
139
+ feFuncR = etree.SubElement(gamma, 'feFuncR')
140
+ feFuncR.set('type', 'gamma')
141
+ feFuncR.set('amplitude', str(1))
142
+ feFuncR.set('exponent', str(1 / 2.2))
143
+ feFuncG = etree.SubElement(gamma, 'feFuncG')
144
+ feFuncG.set('type', 'gamma')
145
+ feFuncG.set('amplitude', str(1))
146
+ feFuncG.set('exponent', str(1 / 2.2))
147
+ feFuncB = etree.SubElement(gamma, 'feFuncB')
148
+ feFuncB.set('type', 'gamma')
149
+ feFuncB.set('amplitude', str(1))
150
+ feFuncB.set('exponent', str(1 / 2.2))
151
+ feFuncA = etree.SubElement(gamma, 'feFuncA')
152
+ feFuncA.set('type', 'gamma')
153
+ feFuncA.set('amplitude', str(1))
154
+ feFuncA.set('exponent', str(1 / 2.2))
155
+ g.set('style', 'filter:url(#gamma)')
156
+
157
+ # Store color
158
+ for i, shape_group in enumerate(shape_groups):
159
+ def add_color(shape_color, name):
160
+ if isinstance(shape_color, pydiffvg.LinearGradient):
161
+ lg = shape_color
162
+ color = etree.SubElement(defs, 'linearGradient')
163
+ color.set('id', name)
164
+ color.set('x1', str(lg.begin[0].item()))
165
+ color.set('y1', str(lg.begin[1].item()))
166
+ color.set('x2', str(lg.end[0].item()))
167
+ color.set('y2', str(lg.end[1].item()))
168
+ offsets = lg.offsets.data.cpu().numpy()
169
+ stop_colors = lg.stop_colors.data.cpu().numpy()
170
+ for j in range(offsets.shape[0]):
171
+ stop = etree.SubElement(color, 'stop')
172
+ stop.set('offset', str(offsets[j]))
173
+ c = lg.stop_colors[j, :]
174
+ stop.set('stop-color', 'rgb({}, {}, {})'.format(
175
+ int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
176
+ ))
177
+ stop.set('stop-opacity', '{}'.format(c[3]))
178
+ if isinstance(shape_color, pydiffvg.RadialGradient):
179
+ lg = shape_color
180
+ color = etree.SubElement(defs, 'radialGradient')
181
+ color.set('id', name)
182
+ color.set('cx', str(lg.center[0].item() / width))
183
+ color.set('cy', str(lg.center[1].item() / height))
184
+ # this only support width=height
185
+ color.set('r', str(lg.radius[0].item() / width))
186
+ offsets = lg.offsets.data.cpu().numpy()
187
+ stop_colors = lg.stop_colors.data.cpu().numpy()
188
+ for j in range(offsets.shape[0]):
189
+ stop = etree.SubElement(color, 'stop')
190
+ stop.set('offset', str(offsets[j]))
191
+ c = lg.stop_colors[j, :]
192
+ stop.set('stop-color', 'rgb({}, {}, {})'.format(
193
+ int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
194
+ ))
195
+ stop.set('stop-opacity', '{}'.format(c[3]))
196
+
197
+ if shape_group.fill_color is not None:
198
+ add_color(shape_group.fill_color, 'shape_{}_fill'.format(i))
199
+ if shape_group.stroke_color is not None:
200
+ add_color(shape_group.stroke_color, 'shape_{}_stroke'.format(i))
201
+
202
+ for i, shape_group in enumerate(shape_groups):
203
+ shape = shapes[shape_group.shape_ids[0]]
204
+ if isinstance(shape, pydiffvg.Circle):
205
+ shape_node = etree.SubElement(g, 'circle')
206
+ shape_node.set('r', str(shape.radius.item()))
207
+ shape_node.set('cx', str(shape.center[0].item()))
208
+ shape_node.set('cy', str(shape.center[1].item()))
209
+ elif isinstance(shape, pydiffvg.Polygon):
210
+ shape_node = etree.SubElement(g, 'polygon')
211
+ points = shape.points.data.cpu().numpy()
212
+ path_str = ''
213
+ for j in range(0, shape.points.shape[0]):
214
+ path_str += '{} {}'.format(points[j, 0], points[j, 1])
215
+ if j != shape.points.shape[0] - 1:
216
+ path_str += ' '
217
+ shape_node.set('points', path_str)
218
+ elif isinstance(shape, pydiffvg.Path):
219
+ for j, id in enumerate(shape_group.shape_ids):
220
+ shape = shapes[id]
221
+ if isinstance(shape, pydiffvg.Path):
222
+ if j == 0:
223
+ shape_node = etree.SubElement(g, 'path')
224
+ node_id = shape_node.get('id')
225
+ path_str = ''
226
+
227
+ num_segments = shape.num_control_points.shape[0]
228
+ num_control_points = shape.num_control_points.data.cpu().numpy()
229
+ points = shape.points.data.cpu().numpy()
230
+ num_points = shape.points.shape[0]
231
+ path_str += 'M {} {}'.format(points[0, 0], points[0, 1])
232
+ point_id = 1
233
+ for j in range(0, num_segments):
234
+ if num_control_points[j] == 0:
235
+ p = point_id % num_points
236
+ path_str += ' L {} {}'.format(
237
+ points[p, 0], points[p, 1])
238
+ point_id += 1
239
+ elif num_control_points[j] == 1:
240
+ p1 = (point_id + 1) % num_points
241
+ path_str += ' Q {} {} {} {}'.format(
242
+ points[point_id, 0], points[point_id, 1],
243
+ points[p1, 0], points[p1, 1])
244
+ point_id += 2
245
+ elif num_control_points[j] == 2:
246
+ p2 = (point_id + 2) % num_points
247
+ path_str += ' C {} {} {} {} {} {}'.format(
248
+ points[point_id, 0], points[point_id, 1],
249
+ points[point_id + 1, 0], points[point_id + 1, 1],
250
+ points[p2, 0], points[p2, 1])
251
+ point_id += 3
252
+ if node_id is not None:
253
+ shape_node.set('id', node_id) # add id to Path
254
+ shape_node.set('d', path_str)
255
+ elif isinstance(shape, pydiffvg.Rect):
256
+ shape_node = etree.SubElement(g, 'rect')
257
+ shape_node.set('x', str(shape.p_min[0].item()))
258
+ shape_node.set('y', str(shape.p_min[1].item()))
259
+ shape_node.set('width', str(shape.p_max[0].item() - shape.p_min[0].item()))
260
+ shape_node.set('height', str(shape.p_max[1].item() - shape.p_min[1].item()))
261
+ elif isinstance(shape, pydiffvg.Ellipse):
262
+ shape_node = etree.SubElement(g, 'ellipse')
263
+ shape_node.set('cx', str(shape.center[0].item()))
264
+ shape_node.set('cy', str(shape.center[1].item()))
265
+ shape_node.set('rx', str(shape.radius[0].item()))
266
+ shape_node.set('ry', str(shape.radius[1].item()))
267
+ else:
268
+ raise NotImplementedError(f'shape type: {type(shape)} is not supported by pydiffvg.')
269
+
270
+ shape_node.set('stroke-width', str(2 * shape.stroke_width.data.cpu().item()))
271
+ if shape_group.fill_color is not None:
272
+ if isinstance(shape_group.fill_color, pydiffvg.LinearGradient):
273
+ shape_node.set('fill', 'url(#shape_{}_fill)'.format(i))
274
+ else:
275
+ c = shape_group.fill_color.data.cpu().numpy()
276
+ shape_node.set('fill', 'rgb({}, {}, {})'.format(
277
+ int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
278
+ shape_node.set('opacity', str(c[3]))
279
+ else:
280
+ shape_node.set('fill', 'none')
281
+ if shape_group.stroke_color is not None:
282
+ if isinstance(shape_group.stroke_color, pydiffvg.LinearGradient):
283
+ shape_node.set('stroke', 'url(#shape_{}_stroke)'.format(i))
284
+ else:
285
+ c = shape_group.stroke_color.data.cpu().numpy()
286
+ shape_node.set('stroke', 'rgb({}, {}, {})'.format(
287
+ int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
288
+ shape_node.set('stroke-opacity', str(c[3]))
289
+ shape_node.set('stroke-linecap', 'round')
290
+ shape_node.set('stroke-linejoin', 'round')
291
+
292
+ with open(filename, "w") as f:
293
+ f.write(pydiffvg.prettify(root))
294
+
295
+ @staticmethod
296
+ def save_image(img, filename, gamma=1):
297
+ if torch.is_tensor(img):
298
+ img = img.detach().cpu()
299
+ pydiffvg.imwrite(img, filename, gamma=gamma)
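A hedged sketch of how this base class is meant to be specialized (the concrete painters in this commit supply their own `para_bg`, which `render_image` expects; `SimplePainter` below is hypothetical):

    # hypothetical subclass for illustration only
    import torch
    from svgdreamer.diffvg_warp import DiffVGState

    class SimplePainter(DiffVGState):
        def __init__(self, device, canvas_size: int = 600):
            super().__init__(device, canvas_width=canvas_size, canvas_height=canvas_size)
            # white background used by render_image()
            self.para_bg = torch.tensor([1., 1., 1.], device=device)

        def clip_curve_shape(self):
            # keep RGBA fill colors in [0, 1] before rasterization
            for group in self.cur_shape_groups:
                group.fill_color.data.clamp_(0.0, 1.0)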
svgdreamer/libs/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description: a self-consistent system,
5
+ # including a runner, trainer, loss functions, EMA, optimizer, lr scheduler, and common utils.
6
+
7
+ from .model_state import ModelState
8
+ from .optim import get_optimizer
svgdreamer/libs/logging.py ADDED
@@ -0,0 +1,65 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ import os
7
+ import sys
8
+ import errno
9
+
10
+
11
+ def get_logger(logs_dir: str, file_name: str = "log.txt"):
12
+ logger = PrintLogger(os.path.join(logs_dir, file_name))
13
+ sys.stdout = logger # record all python print
14
+ return logger
15
+
16
+
17
+ class PrintLogger(object):
18
+
19
+ def __init__(self, fpath=None):
20
+ """
21
+ python standard input/output records
22
+ """
23
+ self.console = sys.stdout
24
+ self.file = None
25
+ if fpath is not None:
26
+ mkdir_if_missing(os.path.dirname(fpath))
27
+ self.file = open(fpath, 'w')
28
+
29
+ def __del__(self):
30
+ self.close()
31
+
32
+ def __enter__(self):
33
+ pass
34
+
35
+ def __exit__(self, *args):
36
+ self.close()
37
+
38
+ def write(self, msg):
39
+ self.console.write(msg)
40
+ if self.file is not None:
41
+ self.file.write(msg)
42
+
43
+ def write_in(self, msg):
44
+ """write in log only, not console"""
45
+ if self.file is not None:
46
+ self.file.write(msg)
47
+
48
+ def flush(self):
49
+ self.console.flush()
50
+ if self.file is not None:
51
+ self.file.flush()
52
+ os.fsync(self.file.fileno())
53
+
54
+ def close(self):
55
+ self.console.close()
56
+ if self.file is not None:
57
+ self.file.close()
58
+
59
+
60
+ def mkdir_if_missing(dir_path):
61
+ try:
62
+ os.makedirs(dir_path)
63
+ except OSError as e:
64
+ if e.errno != errno.EEXIST:
65
+ raise
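A hedged usage sketch of the logger above (the output directory is illustrative):

    # after get_logger(), every print() is mirrored into <logs_dir>/log.txt
    from svgdreamer.libs.logging import get_logger

    logger = get_logger(logs_dir="./workdir/demo")   # hypothetical output directory
    print("written to the console and to ./workdir/demo/log.txt")
    logger.write_in("written to the log file only\n")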
svgdreamer/libs/model_state.py ADDED
@@ -0,0 +1,253 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+
6
+ from typing import Union, List
7
+ from pathlib import Path
8
+ from datetime import datetime
9
+ import logging
10
+
11
+ from omegaconf import OmegaConf, DictConfig
12
+ from pprint import pprint
13
+ import torch
14
+ from accelerate import Accelerator
15
+
16
+ from .logging import get_logger
17
+
18
+
19
+ class ModelState:
20
+ """
21
+ Handling logger and `hugging face` accelerate training
22
+
23
+ features:
24
+ - Precision
25
+ - Device
26
+ - Optimizer
27
+ - Logger (default: python system print and logging)
28
+ - Monitor (default: wandb, tensorboard)
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ args: DictConfig,
34
+ log_path_suffix: str = None,
35
+ ignore_log=False, # whether to create log file or not
36
+ ) -> None:
37
+ self.args: DictConfig = args
38
+ # set cfg
39
+ self.state_cfg = args.state
40
+ self.x_cfg = args.x
41
+
42
+ """check valid"""
43
+ mixed_precision = self.state_cfg.get("mprec")
44
+ # Note: OmegaConf may convert the string 'no' to boolean False
45
+ mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision
46
+
47
+ """create working space"""
48
+ # rule: ['./config', 'method_name', 'exp_name.yaml']
49
+ # -> result_path: ./runs/{method_name}-{exp_name}, as a base folder
50
+ now_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
51
+ results_folder = self.args.get("result_path", None)
52
+ if results_folder is None:
53
+ self.result_path = Path("./workdir") / f"SVGDreamer-{now_time}"
54
+ else:
55
+ self.result_path = Path(results_folder) / f"SVGDreamer-{now_time}"
56
+
57
+ # update result_path: ./runs/{method_name}-{exp_name}/{log_path_suffix}
58
+ # note: can be understood as "results dir / methods / ablation study / your result"
59
+ if log_path_suffix is not None:
60
+ self.result_path = self.result_path / f"{log_path_suffix}"
61
+ else:
62
+ self.result_path = self.result_path / f"SVGDreamer"
63
+
64
+ """init visualized tracker"""
65
+ # TODO: monitor with WANDB or TENSORBOARD
66
+ self.log_with = []
67
+ # if self.state_cfg.wandb:
68
+ # self.log_with.append(LoggerType.WANDB)
69
+ # if self.state_cfg.tensorboard:
70
+ # self.log_with.append(LoggerType.TENSORBOARD)
71
+
72
+ """HuggingFace Accelerator"""
73
+ self.accelerator = Accelerator(
74
+ device_placement=True,
75
+ mixed_precision=mixed_precision,
76
+ cpu=True if self.state_cfg.cpu else False,
77
+ log_with=None if len(self.log_with) == 0 else self.log_with,
78
+ project_dir=self.result_path / "vis",
79
+ )
80
+
81
+ """logs"""
82
+ if self.accelerator.is_local_main_process:
83
+ # logging
84
+ self.log = logging.getLogger(__name__)
85
+
86
+ # log results in a folder periodically
87
+ self.result_path.mkdir(parents=True, exist_ok=True)
88
+ if not ignore_log:
89
+ self.logger = get_logger(
90
+ logs_dir=self.result_path.as_posix(),
91
+ file_name=f"{now_time}-{args.seed}-log.txt"
92
+ )
93
+
94
+ print("==> system args: ")
95
+ sys_cfg = OmegaConf.masked_copy(args, ["x"])
96
+ print(sys_cfg)
97
+ print("==> yaml config args: ")
98
+ print(self.x_cfg)
99
+
100
+ print("\n***** Model State *****")
101
+ print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp}")
102
+ print(f"-> Weight dtype: {self.weight_dtype}")
103
+
104
+ if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled:
105
+ print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}")
106
+
107
+ print(f"-> Working Space: '{self.result_path}'")
108
+
109
+ """glob step"""
110
+ self.step = 0
111
+
112
+ """log process"""
113
+ self.accelerator.wait_for_everyone()
114
+ print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}')
115
+
116
+ self.print("-> state initialization complete \n")
117
+
118
+ @property
119
+ def device(self):
120
+ return self.accelerator.device
121
+
122
+ @property
123
+ def is_main_process(self):
124
+ return self.accelerator.is_main_process
125
+
126
+ @property
127
+ def weight_dtype(self):
128
+ weight_dtype = torch.float32
129
+ if self.accelerator.mixed_precision == "fp16":
130
+ weight_dtype = torch.float16
131
+ elif self.accelerator.mixed_precision == "bf16":
132
+ weight_dtype = torch.bfloat16
133
+ return weight_dtype
134
+
135
+ @property
136
+ def n_gpus(self):
137
+ return self.accelerator.num_processes
138
+
139
+ @property
140
+ def no_decay_params_names(self):
141
+ no_decay = [
142
+ "bn", "LayerNorm", "GroupNorm",
143
+ ]
144
+ return no_decay
145
+
146
+ def no_decay_params(self, model, weight_decay):
147
+ """optimization tricks"""
148
+ optimizer_grouped_parameters = [
149
+ {
150
+ "params": [
151
+ p for n, p in model.named_parameters()
152
+ if not any(nd in n for nd in self.no_decay_params_names)
153
+ ],
154
+ "weight_decay": weight_decay,
155
+ },
156
+ {
157
+ "params": [
158
+ p for n, p in model.named_parameters()
159
+ if any(nd in n for nd in self.no_decay_params_names)
160
+ ],
161
+ "weight_decay": 0.0,
162
+ },
163
+ ]
164
+ return optimizer_grouped_parameters
165
+
166
+ def optimized_params(self, model: torch.nn.Module, verbose=True) -> List:
167
+ """return parameters if `requires_grad` is True
168
+
169
+ Args:
170
+ model: pytorch models
171
+ verbose: log optimized parameters
172
+
173
+ Examples:
174
+ >>> params_optimized = self.optimized_params(uvit, verbose=True)
175
+ >>> optimizer = torch.optim.AdamW(params_optimized, lr=1e-3)
176
+
177
+ Returns:
178
+ a list of parameters
179
+ """
180
+ params_optimized = []
181
+ for key, value in model.named_parameters():
182
+ if value.requires_grad:
183
+ params_optimized.append(value)
184
+ if verbose:
185
+ self.print("\t {}, {}, {}".format(key, value.numel(), value.shape))
186
+ return params_optimized
187
+
188
+ def save_everything(self, fpath: str):
189
+ """Saving and loading the model, optimizer, RNG generators, and the GradScaler."""
190
+ if not self.accelerator.is_main_process:
191
+ return
192
+ self.accelerator.save_state(fpath)
193
+
194
+ def load_save_everything(self, fpath: str):
195
+ """Loading the model, optimizer, RNG generators, and the GradScaler."""
196
+ self.accelerator.load_state(fpath)
197
+
198
+ def save(self, milestone: Union[str, float, int], checkpoint: object) -> None:
199
+ if not self.accelerator.is_main_process:
200
+ return
201
+
202
+ torch.save(checkpoint, self.result_path / f'model-{milestone}.pt')
203
+
204
+ def save_in(self, root: Union[str, Path], checkpoint: object) -> None:
205
+ if not self.accelerator.is_main_process:
206
+ return
207
+
208
+ torch.save(checkpoint, root)
209
+
210
+ def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False):
211
+ ckpt = torch.load(path, map_location=self.device)
212
+
213
+ unwrapped_model = self.accelerator.unwrap_model(model)
214
+ if rm_module_prefix:
215
+ unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})
216
+ else:
217
+ unwrapped_model.load_state_dict(ckpt)
218
+ return unwrapped_model
219
+
220
+ def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]):
221
+ ckpt = torch.load(path, map_location=self.accelerator.device)
222
+ self.print(f"pretrained_dict len: {len(ckpt)}")
223
+ unwrapped_model = self.accelerator.unwrap_model(model)
224
+ model_dict = unwrapped_model.state_dict()
225
+ pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
226
+ model_dict.update(pretrained_dict)
227
+ unwrapped_model.load_state_dict(model_dict, strict=False)
228
+ self.print(f"selected pretrained_dict: {len(model_dict)}")
229
+ return unwrapped_model
230
+
231
+ def print(self, *args, **kwargs):
232
+ """Use in replacement of `print()` to only print once per server."""
233
+ self.accelerator.print(*args, **kwargs)
234
+
235
+ def pretty_print(self, msg):
236
+ if self.accelerator.is_main_process:
237
+ pprint(dict(msg))
238
+
239
+ def close_tracker(self):
240
+ self.accelerator.end_training()
241
+
242
+ def free_memory(self):
243
+ self.accelerator.clear()
244
+
245
+ def close(self, msg: str = "Training complete."):
246
+ """Use in end of training."""
247
+ self.free_memory()
248
+
249
+ if torch.cuda.is_available():
250
+ self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
251
+ if len(self.log_with) > 0:
252
+ self.close_tracker()
253
+ self.print(msg)
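A hedged sketch of the minimal configuration fields this class actually reads (values below are illustrative, not the project defaults):

    # illustrative only; real runs build this cfg via Hydra from conf/config.yaml
    from omegaconf import OmegaConf
    from svgdreamer.libs import ModelState

    args = OmegaConf.create({
        "seed": 1,
        "result_path": "./workdir",
        "state": {"mprec": "no", "cpu": False},   # mixed precision / CPU flags
        "x": {"image_size": 600},                 # stands in for a conf/x/*.yaml style config
    })
    state = ModelState(args, log_path_suffix="demo")
    state.print(f"device: {state.device}, weight dtype: {state.weight_dtype}")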
svgdreamer/libs/optim.py ADDED
@@ -0,0 +1,58 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Author: ximing
3
+ # Description: optimizers
4
+ # Copyright (c) 2023, XiMing Xing.
5
+ # License: MIT License
6
+ from functools import partial
7
+
8
+ import torch
9
+ from omegaconf import DictConfig
10
+
11
+
12
+ def get_optimizer(optimizer_name, parameters, lr=None, config: DictConfig = None):
13
+ param_dict = {}
14
+ if optimizer_name == "adam":
15
+ optimizer = partial(torch.optim.Adam, params=parameters)
16
+ if lr is not None:
17
+ optimizer = partial(torch.optim.Adam, params=parameters, lr=lr)
18
+ if config.get('betas'):
19
+ param_dict['betas'] = config.betas
20
+ if config.get('weight_decay'):
21
+ param_dict['weight_decay'] = config.weight_decay
22
+ if config.get('eps'):
23
+ param_dict['eps'] = config.eps
24
+ elif optimizer_name == "adamW":
25
+ optimizer = partial(torch.optim.AdamW, params=parameters)
26
+ if lr is not None:
27
+ optimizer = partial(torch.optim.AdamW, params=parameters, lr=lr)
28
+ if config.get('betas'):
29
+ param_dict['betas'] = config.betas
30
+ if config.get('weight_decay'):
31
+ param_dict['weight_decay'] = config.weight_decay
32
+ if config.get('eps'):
33
+ param_dict['eps'] = config.eps
34
+ elif optimizer_name == "radam":
35
+ optimizer = partial(torch.optim.RAdam, params=parameters)
36
+ if lr is not None:
37
+ optimizer = partial(torch.optim.RAdam, params=parameters, lr=lr)
38
+ if config.get('betas'):
39
+ param_dict['betas'] = config.betas
40
+ if config.get('weight_decay'):
41
+ param_dict['weight_decay'] = config.weight_decay
42
+ elif optimizer_name == "sgd":
43
+ optimizer = partial(torch.optim.SGD, params=parameters)
44
+ if lr is not None:
45
+ optimizer = partial(torch.optim.SGD, params=parameters, lr=lr)
46
+ if config.get('momentum'):
47
+ param_dict['momentum'] = config.momentum
48
+ if config.get('weight_decay'):
49
+ param_dict['weight_decay'] = config.weight_decay
50
+ if config.get('nesterov'):
51
+ param_dict['nesterov'] = config.nesterov
52
+ else:
53
+ raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")
54
+
55
+ if len(param_dict.keys()) > 0:
56
+ return optimizer(**param_dict)
57
+ else:
58
+ return optimizer()
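A hedged usage sketch, mirroring the `optim` block of the style configs earlier in this commit (the parameters are illustrative; in the project they come from the painter):

    # illustrative only; not part of the commit
    import torch
    from omegaconf import OmegaConf
    from svgdreamer.libs import get_optimizer

    params = [torch.nn.Parameter(torch.zeros(8, 2))]
    optim_cfg = OmegaConf.create({"name": "adam", "betas": [0.9, 0.9], "eps": 1e-6})
    optimizer = get_optimizer(optim_cfg.name, params, lr=1.0, config=optim_cfg)
    optimizer.zero_grad()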
svgdreamer/painter/VPSD_pipeline.py ADDED
@@ -0,0 +1,585 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+ import re
6
+ import PIL
7
+ from PIL import Image
8
+ from typing import Any, List, Optional, Union, Dict
9
+ from omegaconf import DictConfig
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torchvision import transforms
15
+ from diffusers import StableDiffusionPipeline, UNet2DConditionModel
16
+ from diffusers import DDIMScheduler
17
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
18
+ rescale_noise_cfg, StableDiffusionPipelineOutput)
19
+ from diffusers.models.attention_processor import LoRAAttnProcessor
20
+ from diffusers.loaders import AttnProcsLayers
21
+
22
+ from svgdreamer.diffusers_warp import init_StableDiffusion_pipeline, init_diffusers_unet
23
+
24
+
25
+ class VectorizedParticleSDSPipeline(torch.nn.Module):
26
+
27
+ def __init__(self, model_cfg: DictConfig, diffuser_cfg: DictConfig, guidance_cfg: DictConfig, device: torch.device):
28
+ super().__init__()
29
+ self.device = device
30
+ assert guidance_cfg.n_particle >= guidance_cfg.vsd_n_particle
31
+ assert guidance_cfg.n_particle >= guidance_cfg.phi_n_particle
32
+
33
+ pipe_kwargs = {
34
+ "device": self.device,
35
+ "torch_dtype": torch.float32,
36
+ "local_files_only": not diffuser_cfg.download,
37
+ "force_download": diffuser_cfg.force_download,
38
+ "resume_download": diffuser_cfg.resume_download,
39
+ "ldm_speed_up": model_cfg.ldm_speed_up,
40
+ "enable_xformers": model_cfg.enable_xformers,
41
+ "gradient_checkpoint": model_cfg.gradient_checkpoint,
42
+ "cpu_offload": model_cfg.cpu_offload,
43
+ "vae_slicing": False
44
+ }
45
+
46
+ # load pretrained model
47
+ self.sd_pipeline = init_StableDiffusion_pipeline(
48
+ model_cfg.model_id,
49
+ custom_pipeline=StableDiffusionPipeline,
50
+ custom_scheduler=DDIMScheduler,
51
+ **pipe_kwargs
52
+ )
53
+ # disable grads
54
+ self.sd_pipeline.vae.requires_grad_(False)
55
+ self.sd_pipeline.text_encoder.requires_grad_(False)
56
+ self.sd_pipeline.unet.requires_grad_(False)
57
+ # set components
58
+ self.vae = self.sd_pipeline.vae
59
+ self.unet = self.sd_pipeline.unet
60
+ self.scheduler = self.sd_pipeline.scheduler
61
+ self.tokenizer = self.sd_pipeline.tokenizer
62
+ self.text_encoder = self.sd_pipeline.text_encoder
63
+
64
+ if guidance_cfg.phi_model == 'lora':
65
+ if guidance_cfg.phi_single: # default, use the single unet
66
+ # load LoRA model from the pretrained model
67
+ unet_ = self.unet
68
+ else:
69
+ # create a new unet model
70
+ pipe_kwargs.pop('cpu_offload')
71
+ pipe_kwargs.pop('vae_slicing')
72
+ unet_ = init_diffusers_unet(model_cfg.model_id, **pipe_kwargs)
73
+
74
+ # set correct LoRA layers
75
+ self.unet_phi, phi_model_layers = self.set_lora_layers(unet_)
76
+ self.phi_params = list(phi_model_layers.parameters())
77
+ self.lora_cross_attention_kwargs = {"scale": guidance_cfg.lora_attn_scale} \
78
+ if guidance_cfg.use_attn_scale else {}
79
+ self.vae_phi = self.vae
80
+ self.vae_phi.requires_grad_(False)
81
+
82
+ elif guidance_cfg.phi_model == 'unet_simple':
83
+ self.unet_phi = UNet2DConditionModel(
84
+ sample_size=64,
85
+ in_channels=4,
86
+ out_channels=4,
87
+ layers_per_block=1,
88
+ block_out_channels=(128, 256, 384, 512),
89
+ down_block_types=(
90
+ "DownBlock2D",
91
+ "AttnDownBlock2D",
92
+ "AttnDownBlock2D",
93
+ "AttnDownBlock2D",
94
+ ),
95
+ up_block_types=(
96
+ "AttnUpBlock2D",
97
+ "AttnUpBlock2D",
98
+ "AttnUpBlock2D",
99
+ "UpBlock2D",
100
+ ),
101
+ cross_attention_dim=self.unet.config.cross_attention_dim
102
+ ).to(device)
103
+ self.phi_params = list(self.unet_phi.parameters())
104
+ self.vae_phi = self.vae
105
+ # reset lora
106
+ guidance_cfg.use_attn_scale = False
107
+ guidance_cfg.lora_attn_scale = False
108
+
109
+ # hyper-params
110
+ self.phi_single = guidance_cfg.phi_single
111
+ self.guidance_scale: float = guidance_cfg.guidance_scale
112
+ self.guidance_scale_lora: float = guidance_cfg.phi_guidance_scale
113
+ self.grad_clip_val: Union[float, None] = guidance_cfg.grad_clip_val
114
+ self.vsd_n_particle: int = guidance_cfg.vsd_n_particle
115
+ self.phi_n_particle: int = guidance_cfg.phi_n_particle
116
+ self.t_schedule: str = guidance_cfg.t_schedule
117
+ self.t_range = list(guidance_cfg.t_range)
118
+ print(
119
+ f'n_particles: {guidance_cfg.n_particle}, '
120
+ f'enhance_particles: {guidance_cfg.particle_aug}, '
121
+ f'n_particles of score: {self.vsd_n_particle}, '
122
+ f'n_particles of phi_model: {self.phi_n_particle}, \n'
123
+ f't_range: {self.t_range}, '
124
+ f't_schedule: {self.t_schedule}, \n'
125
+ f'guidance_scale: {self.guidance_scale}, phi_guidance_scale: {self.guidance_scale_lora}.'
126
+ )
127
+ print(f"phi_model: {guidance_cfg.phi_model}, "
128
+ f"use lora_cross_attn: {guidance_cfg.use_attn_scale}, "
129
+ f"lora_attn_scale: {guidance_cfg.lora_attn_scale}. \n")
130
+
131
+ # for convenience
132
+ self.num_train_timesteps = self.scheduler.config.num_train_timesteps
133
+ self.alphas = self.scheduler.alphas_cumprod.to(self.device)
134
+ self.text_embeddings = None
135
+ self.text_embedd_cond, self.text_embedd_uncond = None, None
136
+ self.text_embeddings_phi = None
137
+ self.t = None
138
+
139
+ def set_lora_layers(self, unet): # set correct lora layers
140
+ lora_attn_procs = {}
141
+ for name in unet.attn_processors.keys():
142
+ cross_attention_dim = None if name.endswith("attn1.processor") \
143
+ else unet.config.cross_attention_dim
144
+ if name.startswith("mid_block"):
145
+ hidden_size = unet.config.block_out_channels[-1]
146
+ elif name.startswith("up_blocks"):
147
+ block_id = int(name[len("up_blocks.")])
148
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
149
+ elif name.startswith("down_blocks"):
150
+ block_id = int(name[len("down_blocks.")])
151
+ hidden_size = unet.config.block_out_channels[block_id]
152
+
153
+ lora_attn_procs[name] = LoRAAttnProcessor(
154
+ hidden_size=hidden_size,
155
+ cross_attention_dim=cross_attention_dim
156
+ ).to(self.device)
157
+ unet.set_attn_processor(lora_attn_procs)
158
+ lora_layers = AttnProcsLayers(unet.attn_processors)
159
+
160
+ unet.requires_grad_(False)
161
+ for param in lora_layers.parameters():
162
+ param.requires_grad_(True)
163
+ return unet, lora_layers
164
+
165
+ @torch.no_grad()
166
+ def encode_prompt(self,
167
+ prompt,
168
+ device,
169
+ do_classifier_free_guidance,
170
+ negative_prompt=None):
171
+ # text conditional embed
172
+ text_inputs = self.tokenizer(
173
+ prompt,
174
+ padding="max_length",
175
+ max_length=self.tokenizer.model_max_length,
176
+ truncation=True,
177
+ return_tensors="pt",
178
+ )
179
+ prompt_embeds = self.text_encoder(text_inputs.input_ids.to(device))[0]
180
+
181
+ if do_classifier_free_guidance:
182
+ if negative_prompt is None:
183
+ uncond_tokens = [""]
184
+ elif isinstance(negative_prompt, str):
185
+ uncond_tokens = [negative_prompt]
186
+ else:
187
+ uncond_tokens = negative_prompt
188
+
189
+ # unconditional embed
190
+ uncond_input = self.tokenizer(
191
+ uncond_tokens,
192
+ padding="max_length",
193
+ max_length=prompt_embeds.shape[1],
194
+ truncation=True,
195
+ return_tensors="pt",
196
+ )
197
+ negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device))[0]
198
+
199
+ concat_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
200
+ return concat_prompt_embeds, negative_prompt_embeds, prompt_embeds
201
+
202
+ return prompt_embeds, None, None
203
+
204
+ def sampling(self,
205
+ vae,
206
+ unet,
207
+ scheduler,
208
+ prompt: Union[str, List[str]] = None,
209
+ height: Optional[int] = None,
210
+ width: Optional[int] = None,
211
+ num_inference_steps: int = 50,
212
+ guidance_scale: float = 7.5,
213
+ negative_prompt: Optional[Union[str, List[str]]] = None,
214
+ num_images_per_prompt: Optional[int] = 1,
215
+ eta: float = 0.0,
216
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
217
+ latents: Optional[torch.FloatTensor] = None,
218
+ output_type: Optional[str] = "pil",
219
+ return_dict: bool = True,
220
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
221
+ guidance_rescale: float = 0.0):
222
+
223
+ # 0. Default height and width to unet
224
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
225
+ height = height or unet.config.sample_size * vae_scale_factor
226
+ width = width or unet.config.sample_size * vae_scale_factor
227
+
228
+ # 2. Define call parameters
229
+ if prompt is not None and isinstance(prompt, list):
230
+ batch_size = len(prompt)
231
+ else:
232
+ batch_size = 1
233
+
234
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
235
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
236
+ # corresponds to doing no classifier free guidance.
237
+ do_classifier_free_guidance = guidance_scale > 1.0
238
+
239
+ # 3. Encode input prompt
240
+ prompt_embeds, _, _ = self.encode_prompt(
241
+ prompt,
242
+ self.device,
243
+ do_classifier_free_guidance,
244
+ negative_prompt,
245
+ )
246
+
247
+ # 4. Prepare timesteps
248
+ scheduler.set_timesteps(num_inference_steps, device=self.device)
249
+ timesteps = scheduler.timesteps
250
+
251
+ # 5. Prepare latent variables
252
+ num_channels_latents = unet.config.in_channels
253
+ latents = self.sd_pipeline.prepare_latents(
254
+ batch_size * num_images_per_prompt,
255
+ num_channels_latents,
256
+ height,
257
+ width,
258
+ prompt_embeds.dtype,
259
+ self.device,
260
+ generator,
261
+ latents,
262
+ )
263
+
264
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
265
+ extra_step_kwargs = self.sd_pipeline.prepare_extra_step_kwargs(generator, eta)
266
+
267
+ # 7. Denoising loop
268
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
269
+ with self.sd_pipeline.progress_bar(total=num_inference_steps) as progress_bar:
270
+ for i, t in enumerate(timesteps):
271
+ # expand the latents if we are doing classifier free guidance
272
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
273
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
274
+
275
+ # predict the noise residual
276
+ noise_pred = unet(
277
+ latent_model_input,
278
+ t,
279
+ encoder_hidden_states=prompt_embeds,
280
+ cross_attention_kwargs=cross_attention_kwargs,
281
+ return_dict=False,
282
+ )[0]
283
+
284
+ # perform guidance
285
+ if do_classifier_free_guidance:
286
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
287
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
288
+
289
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
290
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
291
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
292
+
293
+ # compute the previous noisy sample x_t -> x_t-1
294
+ latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
295
+
296
+ # update progress_bar
297
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
298
+ progress_bar.update()
299
+
300
+ if not output_type == "latent":
301
+ image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
302
+ image, has_nsfw_concept = self.sd_pipeline.run_safety_checker(image, self.device, prompt_embeds.dtype)
303
+ else:
304
+ image = latents
305
+ has_nsfw_concept = None
306
+
307
+ if has_nsfw_concept is None:
308
+ do_denormalize = [True] * image.shape[0]
309
+ else:
310
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
311
+
312
+ image = self.sd_pipeline.image_processor.postprocess(image, output_type=output_type,
313
+ do_denormalize=do_denormalize)
314
+ # Offload last model to CPU
315
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
316
+ self.final_offload_hook.offload()
317
+
318
+ if not return_dict:
319
+ return (image, has_nsfw_concept)
320
+
321
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
322
+
323
+ def sample(self,
324
+ prompt,
325
+ height: Optional[int] = None,
326
+ width: Optional[int] = None,
327
+ num_inference_steps: int = 50,
328
+ guidance_scale: float = 7.5,
329
+ negative_prompt: Optional[Union[str, List[str]]] = None,
330
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
331
+ output_type: Optional[str] = "pil"):
332
+ return self.sampling(self.vae, self.unet, self.scheduler,
333
+ prompt=prompt,
334
+ height=height, width=width,
335
+ num_inference_steps=num_inference_steps,
336
+ guidance_scale=guidance_scale,
337
+ negative_prompt=negative_prompt,
338
+ generator=generator,
339
+ output_type=output_type)
340
+
341
+ def sample_lora(self,
342
+ prompt,
343
+ height: Optional[int] = None,
344
+ width: Optional[int] = None,
345
+ num_inference_steps: int = 50,
346
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
347
+ output_type: Optional[str] = "pil"):
348
+ return self.sampling(self.vae_phi, self.unet_phi, self.scheduler,
349
+ prompt=prompt,
350
+ height=height, width=width,
351
+ num_inference_steps=num_inference_steps,
352
+ guidance_scale=self.guidance_scale_lora,
353
+ generator=generator,
354
+ cross_attention_kwargs=self.lora_cross_attention_kwargs,
355
+ output_type=output_type)
356
+
357
+ def encode2latent(self, images):
358
+ images = (2 * images - 1).clamp(-1.0, 1.0) # images: [B, 3, H, W]
359
+ # encode images
360
+ latents = self.vae.encode(images).latent_dist.sample()
361
+ latents = self.vae.config.scaling_factor * latents
362
+ return latents
363
+
364
+ def get_noise_map(self, noise_pred, guidance_scale=7.5, use_cfg=True):
365
+ if use_cfg:
366
+ noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
367
+ noise_map = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
368
+ return noise_map
369
+ else:
370
+ return noise_pred
371
+
372
+ def train_phi_model(self,
373
+ pred_rgb: torch.Tensor,
374
+ new_timesteps: bool = False,
375
+ as_latent: bool = False):
376
+ # interp to 512x512 to be fed into vae.
377
+ if as_latent:
378
+ latents = pred_rgb
379
+ else:
380
+ pred_rgb_ = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
381
+ # encode image into latents with vae, requires grad!
382
+ latents = self.encode2latent(pred_rgb_)
383
+
384
+ # get phi particles
385
+ indices = torch.randperm(latents.size(0))
386
+ latents_phi = latents[indices[:self.phi_n_particle]]
387
+ latents_phi = latents_phi.detach()
388
+
389
+ # get timestep
390
+ if new_timesteps:
391
+ t = torch.randint(0, self.num_train_timesteps, (1,), device=self.device)
392
+ else:
393
+ t = self.t
394
+
395
+ noise = torch.randn_like(latents_phi)
396
+ noisy_latents = self.scheduler.add_noise(latents_phi, noise, t)
397
+
398
+ if self.scheduler.config.prediction_type == "epsilon":
399
+ target = noise
400
+ elif self.scheduler.config.prediction_type == "v_prediction":
401
+ target = self.scheduler.get_velocity(latents_phi, noise, t)
402
+ else:
403
+ raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")
404
+
405
+ # predict the noise residual and compute loss
406
+ noise_pred = self.unet_phi(
407
+ noisy_latents, t,
408
+ encoder_hidden_states=self.text_embeddings_phi,
409
+ cross_attention_kwargs=self.lora_cross_attention_kwargs,
410
+ ).sample
411
+
412
+ return F.mse_loss(noise_pred, target, reduction="mean")
413
+
414
+ def train_phi_model_refl(self,
415
+ pred_rgb: torch.Tensor,
416
+ weight: float = 1,
417
+ new_timesteps: bool = True):
418
+ # interp to 512x512 to be fed into vae.
419
+ pred_rgb_ = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
420
+ # encode image into latents with vae, requires grad!
421
+ latents = self.encode2latent(pred_rgb_)
422
+
423
+ # get phi particles
424
+ indices = torch.randperm(latents.size(0))
425
+ latents_phi = latents[indices[:self.phi_n_particle]]
426
+ latents_phi = latents_phi.detach()
427
+
428
+ # get timestep
429
+ if new_timesteps:
430
+ t = torch.randint(0, self.num_train_timesteps, (1,), device=self.device)
431
+ else:
432
+ t = self.t
433
+
434
+ noise = torch.randn_like(latents_phi)
435
+ noisy_latents = self.scheduler.add_noise(latents_phi, noise, t)
436
+
437
+ if self.scheduler.config.prediction_type == "epsilon":
438
+ target = noise
439
+ elif self.scheduler.config.prediction_type == "v_prediction":
440
+ target = self.scheduler.get_velocity(latents_phi, noise, t)
441
+ else:
442
+ raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")
443
+
444
+ # predict the noise residual and compute loss
445
+ noise_pred = self.unet_phi(
446
+ noisy_latents, t,
447
+ encoder_hidden_states=self.text_embedd_cond,
448
+ cross_attention_kwargs=self.lora_cross_attention_kwargs,
449
+ ).sample
450
+
451
+ rewards = torch.tensor(weight, dtype=torch.float32, device=self.device)
452
+ return rewards * F.mse_loss(noise_pred, target, reduction="mean")
453
+
454
+ def schedule_timestep(self, step):
455
+ min_step = int(self.num_train_timesteps * self.t_range[0])
456
+ max_step = int(self.num_train_timesteps * self.t_range[1])
457
+ if self.t_schedule == 'randint':
458
+ t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
459
+ elif re.match(r"max_([\d.]+)_(\d+)", self.t_schedule):
460
+ # Anneal time schedule
461
+ # e.g: t_schedule == 'max_0.5_200'
462
+ # [0.02, 0.98] -> [0.02, 0.5] after 200 steps
463
+ tag, t_val, step_upd = str(self.t_schedule).split('_')
464
+ t_val, step_upd = float(t_val), int(step_upd)
465
+ if step >= step_upd:
466
+ max_step = int(self.num_train_timesteps * t_val)
467
+ t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
468
+ elif re.match(r"min_([\d.]+)_(\d+)", self.t_schedule):
469
+ # Anneal time schedule
470
+ # e.g: t_schedule == 'min_0.5_200'
471
+ # [0.02, 0.98] -> [0.5, 0.98] after 200 steps
472
+ tag, t_val, step_upd = str(self.t_schedule).split('_')
473
+ t_val, step_upd = float(t_val), int(step_upd)
474
+ if step >= step_upd:
475
+ min_step = int(self.num_train_timesteps * t_val)
476
+ t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
477
+ else:
478
+ raise NotImplementedError(f"{self.t_schedule} is not support.")
479
+ return t
480
+
481
+ def set_text_embeddings(self, prompt, negative_prompt, do_classifier_free_guidance):
482
+ if self.text_embeddings is not None:
483
+ return
484
+
485
+ # encode text prompt
486
+ text_embeddings, text_embeddings_uncond, text_embeddings_cond = \
487
+ self.encode_prompt(prompt, self.device, do_classifier_free_guidance, negative_prompt=negative_prompt)
488
+
489
+ # set pretrained model text embedding
490
+ text_embeddings_uncond, text_embeddings_cond = text_embeddings.chunk(2)
491
+ self.text_embedd_uncond, self.text_embedd_cond = text_embeddings_uncond, text_embeddings_cond
492
+ text_embeddings_unconds = text_embeddings_uncond.repeat_interleave(self.vsd_n_particle, dim=0)
493
+ text_embeddings_conds = text_embeddings_cond.repeat_interleave(self.vsd_n_particle, dim=0)
494
+ text_embeddings = torch.cat([text_embeddings_unconds, text_embeddings_conds])
495
+ self.text_embeddings = text_embeddings
496
+
497
+ # set phi model text embedding
498
+ self.text_embeddings_phi = text_embeddings_cond.repeat_interleave(self.phi_n_particle, dim=0)
499
+
500
+ def x_augment(self, x: torch.Tensor, img_size: int = 512):
501
+ augment_compose = transforms.Compose([
502
+ transforms.RandomPerspective(distortion_scale=0.5, p=0.7),
503
+ transforms.RandomCrop(size=(img_size, img_size), pad_if_needed=True, padding_mode='reflect')
504
+ ])
505
+ return augment_compose(x)
506
+
507
+ def variational_score_distillation(self,
508
+ pred_rgb: torch.Tensor,
509
+ step: int,
510
+ prompt: Union[List, str],
511
+ negative_prompt: Union[List, str] = None,
512
+ grad_scale: float = 1.0,
513
+ enhance_particle: bool = False,
514
+ im_size: int = 512,
515
+ as_latent: bool = False):
516
+ bz = pred_rgb.shape[0]
517
+
518
+ # data augmentation for the input particles
519
+ pred_rgb = self.x_augment(pred_rgb, im_size) if enhance_particle else pred_rgb
520
+
521
+ # interp to 512x512 to be fed into vae.
522
+ if as_latent:
523
+ latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
524
+ else:
525
+ pred_rgb_ = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
526
+ # encode image into latents with vae, requires grad!
527
+ # latents = self.encode2latent(pred_rgb_)
528
+ latent_list = [self.encode2latent(pred_rgb_[i].unsqueeze(0)) for i in range(bz)]
529
+ latents = torch.cat(latent_list, dim=0)
530
+ latents = latents.to(self.device)
531
+
532
+ # randomly sample vsd_n_particle particles from the latents
533
+ latents_vsd = latents[torch.randperm(bz)[:self.vsd_n_particle]]
534
+
535
+ # encode input prompt
536
+ do_classifier_free_guidance = True
537
+ self.set_text_embeddings(prompt, negative_prompt, do_classifier_free_guidance)
538
+ text_embeddings = self.text_embeddings
539
+
540
+ # timestep, a.k.a. noise level
541
+ self.t = self.schedule_timestep(step)
542
+
543
+ # predict the noise residual with unet, stop gradient
544
+ with torch.no_grad():
545
+ # add noise
546
+ noise = torch.randn_like(latents_vsd)
547
+ latents_noisy = self.scheduler.add_noise(latents_vsd, noise, self.t)
548
+ # pred noise
549
+ latent_model_input = torch.cat([latents_noisy] * 2) if do_classifier_free_guidance else latents_noisy
550
+ # pretrained noise prediction network
551
+ noise_pred_pretrain = self.unet(
552
+ latent_model_input, self.t,
553
+ encoder_hidden_states=text_embeddings,
554
+ cross_attention_kwargs={'scale': 0.0} if self.phi_single else {}
555
+ ).sample
556
+
557
+ # use conditional text embeddings in phi_model
558
+ _, text_embeddings_cond = text_embeddings.chunk(2)
559
+ # estimated noise prediction network
560
+ noise_pred_est = self.unet_phi(
561
+ latents_noisy, self.t,
562
+ encoder_hidden_states=text_embeddings_cond,
563
+ cross_attention_kwargs=self.lora_cross_attention_kwargs
564
+ ).sample
565
+
566
+ # get pretrained score
567
+ noise_pred_pretrain = self.get_noise_map(noise_pred_pretrain, self.guidance_scale, use_cfg=True)
568
+ # get estimated score
569
+ noise_pred_est = self.get_noise_map(noise_pred_est, self.guidance_scale_lora, use_cfg=False)
570
+
571
+ # w(t), sigma_t^2
572
+ w = (1 - self.alphas[self.t])
573
+ grad = grad_scale * w * (noise_pred_pretrain - noise_pred_est.detach())
574
+ grad = torch.nan_to_num(grad)
575
+
576
+ # grad clipping for stable training
577
+ if self.grad_clip_val is not None and self.grad_clip_val > 0:
578
+ grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
579
+
580
+ # re-parameterization trick:
581
+ # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
582
+ target = (latents_vsd - grad).detach()
583
+ loss_vpsd = 0.5 * F.mse_loss(latents_vsd, target, reduction="sum")
584
+
585
+ return loss_vpsd, grad.norm(), latents, self.t
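The surrogate loss at the end of variational_score_distillation relies on a re-parameterization trick: since target = (latents_vsd - grad).detach() is cut out of the autograd graph, backpropagating 0.5 * mse_loss(latents_vsd, target, reduction="sum") pushes exactly grad into the latents. A minimal standalone sketch of that identity (plain PyTorch, independent of the pipeline above; shapes are illustrative only):

import torch
import torch.nn.functional as F

# d/dx [ 0.5 * || x - stop_grad(x - g) ||^2 ] = x - (x - g) = g
latents = torch.randn(2, 4, 64, 64, requires_grad=True)
grad = torch.randn_like(latents)  # stands in for w(t) * (noise_pred_pretrain - noise_pred_est)

target = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents, target, reduction="sum")
loss.backward()

assert torch.allclose(latents.grad, grad, atol=1e-6)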
svgdreamer/painter/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Description:
4
+
5
+ from .painter_params import (
6
+ Painter, PainterOptimizer, CosineWithWarmupLRLambda, RandomCoordInit, NaiveCoordInit, SparseCoordInit, get_sdf)
7
+ from .component_painter_params import CompPainter, CompPainterOptimizer
8
+ from .loss import xing_loss_fn
9
+ from .VPSD_pipeline import VectorizedParticleSDSPipeline
10
+ from .diffusion_pipeline import DiffusionPipeline
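For orientation, the re-exports above define the package-level import surface used elsewhere in the repo. A minimal sketch of the resulting imports (names exactly as declared above; assumes the package and its pydiffvg/diffusers dependencies are installed):

from svgdreamer.painter import (
    Painter, PainterOptimizer,          # full-canvas SVG painter and its optimizer
    CompPainter, CompPainterOptimizer,  # component-wise painter defined in the next file
    xing_loss_fn,                       # self-intersection (Xing) regularization loss
    VectorizedParticleSDSPipeline,      # VPSD guidance pipeline
    DiffusionPipeline,                  # SD sampling and attention-map extraction
)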
svgdreamer/painter/component_painter_params.py ADDED
@@ -0,0 +1,610 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Author: ximing
3
+ # Description: content painter and optimizer
4
+ # Copyright (c) 2023, XiMing Xing.
5
+ # License: MIT License
6
+
7
+ import copy
8
+ import math
9
+ import random
10
+ import pathlib
11
+ from typing import Dict, Tuple
12
+
13
+ from shapely.geometry.polygon import Polygon
14
+ from omegaconf import DictConfig
15
+ import numpy as np
16
+ import pydiffvg
17
+ import torch
18
+ from torch.optim.lr_scheduler import LambdaLR
19
+
20
+ from svgdreamer.painter import (SparseCoordInit, RandomCoordInit, NaiveCoordInit, get_sdf)
21
+ from svgdreamer.libs import get_optimizer
22
+
23
+
24
+ class CompPainter:
25
+
26
+ def __init__(
27
+ self,
28
+ style: str,
29
+ target_img: torch.Tensor,
30
+ canvas_size: Tuple[int, int] = (600, 600),
31
+ num_segments: int = 4,
32
+ segment_init: str = 'circle',
33
+ radius: int = 20,
34
+ n_grid: int = 32,
35
+ stroke_width: int = 3,
36
+ device=None,
37
+ attn_init: bool = False,
38
+ attention_map: torch.Tensor = None,
39
+ attn_prob_tau: float = None,
40
+ ):
41
+ self.style = style
42
+ self.device = device
43
+ self.target_img = target_img
44
+
45
+ # curve params
46
+ self.num_segments = num_segments
47
+ self.segment_init = segment_init
48
+ self.radius = radius
49
+
50
+ self.canvas_width, self.canvas_height = canvas_size
51
+ """pixelart params"""
52
+ self.n_grid = n_grid # divide the canvas into an n_grid x n_grid grid
53
+ self.pixel_per_grid = self.canvas_width // self.n_grid
54
+ """sketch params"""
55
+ self.stroke_width = stroke_width
56
+ """iconography params"""
57
+ self.color_ref = None
58
+
59
+ self.shapes = [] # record all paths
60
+ self.shape_groups = []
61
+ self.cur_shapes, self.cur_shape_groups = [], [] # record the current optimized path
62
+ self.point_vars = []
63
+ self.color_vars = []
64
+ self.width_vars = []
65
+
66
+ # init
67
+ self.attention_map = attention_map
68
+ self.attn_init = attn_init
69
+ self.attn_prob_tau = attn_prob_tau
70
+ self.select_inds = None
71
+ self.pos_init_method = None
72
+
73
+ # background
74
+ self.para_bg = torch.tensor([1., 1., 1.], requires_grad=False, device=self.device)
75
+ # count the number of strokes
76
+ self.strokes_counter = 0 # counts the number of calls to "get_path"
77
+
78
+ def attn_init_points(self, num_paths, mask=None):
79
+ attn_map = (self.attention_map - self.attention_map.min()) / \
80
+ (self.attention_map.max() - self.attention_map.min())
81
+
82
+ attn_map_soft = np.copy(attn_map)
83
+ attn_map_soft[attn_map > 0] = softmax_t(attn_map[attn_map > 0], tau=self.attn_prob_tau)
84
+ # for visualizing
85
+ attn_thresh = np.copy(attn_map_soft)
86
+ # the probabilities associated with each entry in attn_map
87
+ attn_map_soft /= np.sum(attn_map_soft)
88
+ # select points
89
+ k = num_paths
90
+
91
+ # select k points randomly
92
+ positions = np.where(mask == 1)
93
+ positions = np.stack(positions, axis=1)
94
+ np.random.shuffle(positions)
95
+ positions = positions[:k]
96
+
97
+ # note: used only for visualization
98
+ visual_coords = np.copy(positions)
99
+
100
+ canvas_coords = np.copy(positions)
101
+ canvas_coords[:, [0, 1]] = canvas_coords[:, [1, 0]]
102
+ self.select_inds = canvas_coords
103
+
104
+ # for visualizing
105
+ return attn_thresh, visual_coords
106
+
107
+ def component_wise_path_init(self, pred, init_type: str = 'sparse'):
108
+ if init_type == 'random':
109
+ self.pos_init_method = RandomCoordInit(self.canvas_height, self.canvas_width)
110
+
111
+ elif init_type == 'sparse':
112
+ assert self.target_img is not None # target_img as GT
113
+ # when initialized for the first time, the render result is None
114
+ if pred is None:
115
+ pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width)
116
+ # then pred is the render result
117
+ self.pos_init_method = SparseCoordInit(pred, self.target_img)
118
+
119
+ elif init_type == 'naive':
120
+ assert self.target_img is not None # target_img as GT
121
+ if pred is None:
122
+ pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width)
123
+ self.pos_init_method = NaiveCoordInit(pred, self.target_img)
124
+
125
+ else:
126
+ raise NotImplementedError(f"'{init_type}' is not supported.")
127
+
128
+ def init_image(self, num_paths=0):
129
+ self.cur_shapes, self.cur_shape_groups = [], []
130
+
131
+ if self.style == 'pixelart': # update path definition
132
+ num_paths = self.n_grid
133
+
134
+ for i in range(num_paths):
135
+ if self.style == 'iconography':
136
+ path = self.get_path()
137
+ self.shapes.append(path)
138
+ self.cur_shapes.append(path)
139
+
140
+ wref, href = self.color_ref
141
+ wref = max(0, min(int(wref), self.canvas_width - 1))
142
+ href = max(0, min(int(href), self.canvas_height - 1))
143
+ fill_color_init = list(self.target_img[0, :, href, wref]) + [1.]
144
+ fill_color_init = torch.FloatTensor(fill_color_init)
145
+ path_group = pydiffvg.ShapeGroup(
146
+ shape_ids=torch.tensor([len(self.shapes) - 1]),
147
+ fill_color=fill_color_init,
148
+ stroke_color=None
149
+ )
150
+ self.shape_groups.append(path_group)
151
+ self.cur_shape_groups.append(path_group)
152
+
153
+ elif self.style == 'pixelart':
154
+ fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
155
+ fill_color_init[-1] = 1.0
156
+
157
+ for j in range(num_paths):
158
+ path = self.get_path(coord=[i, j])
159
+ self.shapes.append(path)
160
+ self.cur_shapes.append(path)
161
+
162
+ path_group = pydiffvg.ShapeGroup(
163
+ shape_ids=torch.LongTensor([i * num_paths + j]),
164
+ fill_color=fill_color_init,
165
+ stroke_color=None,
166
+ )
167
+ self.shape_groups.append(path_group)
168
+ self.cur_shape_groups.append(path_group)
169
+
170
+ elif self.style == 'sketch':
171
+ path = self.get_path()
172
+ self.shapes.append(path)
173
+ self.cur_shapes.append(path)
174
+
175
+ stroke_color_init = torch.tensor([0.0, 0.0, 0.0, 1.0])
176
+ path_group = pydiffvg.ShapeGroup(
177
+ shape_ids=torch.tensor([len(self.shapes) - 1]),
178
+ fill_color=None,
179
+ stroke_color=stroke_color_init
180
+ )
181
+ self.shape_groups.append(path_group)
182
+ self.cur_shape_groups.append(path_group)
183
+
184
+ elif self.style == 'painting':
185
+ path = self.get_path()
186
+ self.shapes.append(path)
187
+ self.cur_shapes.append(path)
188
+
189
+ wref, href = self.color_ref
190
+ wref = max(0, min(int(wref), self.canvas_width - 1))
191
+ href = max(0, min(int(href), self.canvas_height - 1))
192
+ stroke_color_init = list(self.target_img[0, :, href, wref]) + [1.]
193
+ path_group = pydiffvg.ShapeGroup(
194
+ shape_ids=torch.tensor([len(self.shapes) - 1]),
195
+ fill_color=None,
196
+ stroke_color=stroke_color_init
197
+ )
198
+ self.shape_groups.append(path_group)
199
+ self.cur_shape_groups.append(path_group)
200
+
201
+ img = self.render_warp()
202
+ img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4])
203
+ img = img.unsqueeze(0) # convert img from HWC to NHWC
204
+ img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW
205
+ return img
206
+
207
+ def get_image(self, step: int = 0):
208
+ img = self.render_warp(step)
209
+ img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4])
210
+ img = img.unsqueeze(0) # convert img from HWC to NHWC
211
+ img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW
212
+ return img
213
+
214
+ def get_path(self, coord=None):
215
+ num_segments = self.num_segments
216
+
217
+ points = []
218
+ if self.style == 'iconography':
219
+ num_control_points = [2] * num_segments
220
+ # init segment
221
+ if self.segment_init == 'circle':
222
+ radius = self.radius if self.radius is not None else np.random.uniform(0.5, 1)
223
+
224
+ if self.attn_init:
225
+ center = self.select_inds[self.strokes_counter] # shape: (2,)
226
+ else:
227
+ center = (random.random(), random.random()) \
228
+ if self.pos_init_method is None else self.pos_init_method()
229
+
230
+ bias = center
231
+ self.color_ref = copy.deepcopy(bias)
232
+
233
+ points = get_circle_coordinates(center, radius, k=num_segments * 3)
234
+ points = torch.FloatTensor(points)
235
+ else:
236
+ if self.attn_init:
237
+ p0 = self.select_inds[self.strokes_counter]
238
+ else:
239
+ p0 = self.pos_init_method()
240
+
241
+ self.color_ref = copy.deepcopy(p0)
242
+ points.append(p0)
243
+ for j in range(num_segments):
244
+ radius = self.radius
245
+ p1 = (p0[0] + radius * np.random.uniform(-0.5, 0.5),
246
+ p0[1] + radius * np.random.uniform(-0.5, 0.5))
247
+ p2 = (p1[0] + radius * np.random.uniform(-0.5, 0.5),
248
+ p1[1] + radius * np.random.uniform(-0.5, 0.5))
249
+ p3 = (p2[0] + radius * np.random.uniform(-0.5, 0.5),
250
+ p2[1] + radius * np.random.uniform(-0.5, 0.5))
251
+ points.append(p1)
252
+ points.append(p2)
253
+ if j < num_segments - 1:
254
+ points.append(p3)
255
+ p0 = p3
256
+ points = torch.FloatTensor(points)
257
+
258
+ path = pydiffvg.Path(num_control_points=torch.LongTensor(num_control_points),
259
+ points=points,
260
+ stroke_width=torch.tensor(0.0),
261
+ is_closed=True)
262
+ elif self.style in ['sketch', 'painting', 'ink']:
263
+ num_control_points = torch.zeros(num_segments, dtype=torch.long) + 2
264
+ points = []
265
+
266
+ if self.attn_init:
267
+ p0 = self.select_inds[self.strokes_counter]
268
+ else:
269
+ p0 = (random.random(), random.random()) \
270
+ if self.pos_init_method is None else self.pos_init_method()
271
+
272
+ self.color_ref = copy.deepcopy(p0)
273
+
274
+ points.append(p0)
275
+ for j in range(num_segments):
276
+ radius = 0.1
277
+ p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
278
+ p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5))
279
+ p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5))
280
+ points.append(p1)
281
+ points.append(p2)
282
+ points.append(p3)
283
+ p0 = p3
284
+ points = torch.tensor(points).to(self.device)
285
+
286
+ if not self.attn_init:
287
+ points[:, 0] *= self.canvas_width
288
+ points[:, 1] *= self.canvas_height
289
+
290
+ path = pydiffvg.Path(num_control_points=torch.LongTensor(num_control_points),
291
+ points=points,
292
+ stroke_width=torch.tensor(self.stroke_width),
293
+ is_closed=False)
294
+ elif self.style == 'pixelart':
295
+ x = coord[0] * self.pixel_per_grid
296
+ y = coord[1] * self.pixel_per_grid
297
+ points = torch.FloatTensor([
298
+ [x, y],
299
+ [x + self.pixel_per_grid, y],
300
+ [x + self.pixel_per_grid, y + self.pixel_per_grid],
301
+ [x, y + self.pixel_per_grid]
302
+ ]).to(self.device)
303
+ path = pydiffvg.Polygon(points=points,
304
+ stroke_width=torch.tensor(0.0),
305
+ is_closed=True)
306
+
307
+ self.strokes_counter += 1
308
+ return path
309
+
310
+ def clip_curve_shape(self):
311
+ for group in self.shape_groups:
312
+ group.fill_color.data.clamp_(0.0, 1.0)
313
+
314
+ def reinitialize_paths(self,
315
+ reinit_path: bool = False,
316
+ opacity_threshold: float = None,
317
+ area_threshold: float = None,
318
+ fpath: pathlib.Path = None):
319
+ """
320
+ Reinitialize paths, also known as 'Reinitializing paths' in the VectorFusion paper.
321
+
322
+ Args:
323
+ reinit_path: whether to reinitialize paths or not.
324
+ opacity_threshold: Threshold of opacity.
325
+ area_threshold: Threshold of the closed polygon area.
326
+ fpath: The path to save the reinitialized SVG.
327
+ """
328
+ if self.style == 'iconography' and reinit_path:
329
+ # re-init by opacity_threshold
330
+ select_path_ids_by_opc = []
331
+ if opacity_threshold != 0 and opacity_threshold is not None:
332
+ def get_keys_below_threshold(my_dict, threshold):
333
+ keys_below_threshold = [key for key, value in my_dict.items() if value < threshold]
334
+ return keys_below_threshold
335
+
336
+ opacity_record_ = {group.shape_ids.item(): group.fill_color.data[-1].item()
337
+ for group in self.cur_shape_groups}
338
+ # print("-> opacity_record: ", opacity_record_)
339
+ print("-> opacity_record: ", [f"{k}: {v:.3f}" for k, v in opacity_record_.items()])
340
+ select_path_ids_by_opc = get_keys_below_threshold(opacity_record_, opacity_threshold)
341
+ print("select_path_ids_by_opc: ", select_path_ids_by_opc)
342
+
343
+ # remove path by area_threshold
344
+ select_path_ids_by_area = []
345
+ if area_threshold != 0 and area_threshold is not None:
346
+ area_records = [Polygon(shape.points.detach().numpy()).area for shape in self.cur_shapes]
347
+ # print("-> area_records: ", area_records)
348
+ print("-> area_records: ", ['%.2f' % i for i in area_records])
349
+ for i, shape in enumerate(self.cur_shapes):
350
+ if Polygon(shape.points.detach().numpy()).area < area_threshold:
351
+ select_path_ids_by_area.append(shape.id)
352
+ print("select_path_ids_by_area: ", select_path_ids_by_area)
353
+
354
+ # re-init paths
355
+ reinit_union = list(set(select_path_ids_by_opc + select_path_ids_by_area))
356
+ if len(reinit_union) > 0:
357
+ for i, path in enumerate(self.cur_shapes):
358
+ if path.id in reinit_union:
359
+ self.cur_shapes[i] = self.get_path()
360
+ for i, group in enumerate(self.cur_shape_groups):
361
+ shp_ids = group.shape_ids.cpu().numpy().tolist()
362
+ if set(shp_ids).issubset(reinit_union):
363
+ fill_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
364
+ fill_color_init[-1] = np.random.uniform(0.7, 1)
365
+ stroke_color_init = torch.FloatTensor(np.random.uniform(size=[4]))
366
+ self.cur_shape_groups[i] = pydiffvg.ShapeGroup(
367
+ shape_ids=torch.tensor(list(shp_ids)),
368
+ fill_color=fill_color_init,
369
+ stroke_color=stroke_color_init)
370
+ # save reinit svg
371
+ self.save_svg(fpath)
372
+
373
+ print("-" * 40)
374
+
375
+ def render_warp(self, seed=0):
376
+ scene_args = pydiffvg.RenderFunction.serialize_scene(
377
+ self.canvas_width, self.canvas_height, self.shapes, self.shape_groups
378
+ )
379
+ _render = pydiffvg.RenderFunction.apply
380
+ img = _render(self.canvas_width, # width
381
+ self.canvas_height, # height
382
+ 2, # num_samples_x
383
+ 2, # num_samples_y
384
+ seed, # seed
385
+ None,
386
+ *scene_args)
387
+ return img
388
+
389
+ def calc_distance_weight(self, loss_weight_keep):
390
+ shapes_forsdf = copy.deepcopy(self.cur_shapes)
391
+ shape_groups_forsdf = copy.deepcopy(self.cur_shape_groups)
392
+ for si in shapes_forsdf:
393
+ si.stroke_width = torch.FloatTensor([0]).to(self.device)
394
+ for sg_idx, sgi in enumerate(shape_groups_forsdf):
395
+ sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(self.device)
396
+ sgi.shape_ids = torch.LongTensor([sg_idx]).to(self.device)
397
+
398
+ sargs_forsdf = pydiffvg.RenderFunction.serialize_scene(
399
+ self.canvas_width, self.canvas_height, shapes_forsdf, shape_groups_forsdf
400
+ )
401
+ _render = pydiffvg.RenderFunction.apply
402
+ with torch.no_grad():
403
+ im_forsdf = _render(self.canvas_width, # width
404
+ self.canvas_height, # height
405
+ 2, # num_samples_x
406
+ 2, # num_samples_y
407
+ 0, # seed
408
+ None,
409
+ *sargs_forsdf)
410
+
411
+ # using the alpha channel is a trick to get a 0-1 image
412
+ im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy()
413
+ loss_weight = get_sdf(im_forsdf, normalize='to1')
414
+ loss_weight += loss_weight_keep
415
+ loss_weight = np.clip(loss_weight, 0, 1)
416
+ loss_weight = torch.FloatTensor(loss_weight).to(self.device)
417
+ return loss_weight
418
+
419
+ def set_points_parameters(self, id_delta=0):
420
+ self.point_vars = []
421
+ for i, path in enumerate(self.cur_shapes):
422
+ path.id = i + id_delta # set path id
423
+ path.points.requires_grad = True
424
+ self.point_vars.append(path.points)
425
+
426
+ def get_point_params(self):
427
+ return self.point_vars
428
+
429
+ def set_color_parameters(self):
430
+ self.color_vars = []
431
+ for i, group in enumerate(self.cur_shape_groups):
432
+ if group.fill_color is not None:
433
+ group.fill_color.requires_grad = True
434
+ self.color_vars.append(group.fill_color)
435
+ if group.stroke_color is not None:
436
+ group.stroke_color.requires_grad = True
437
+ self.color_vars.append(group.stroke_color)
438
+
439
+ def get_color_params(self):
440
+ return self.color_vars
441
+
442
+ def set_width_parameters(self):
443
+ # stroke width optimization
444
+ self.width_vars = []
445
+ for i, path in enumerate(self.shapes):
446
+ path.stroke_width.requires_grad = True
447
+ self.width_vars.append(path.stroke_width)
448
+
449
+ def get_width_params(self):
450
+ return self.width_vars
451
+
452
+ def save_svg(self, fpath):
453
+ pydiffvg.save_svg(f'{fpath}',
454
+ self.canvas_width,
455
+ self.canvas_height,
456
+ self.shapes,
457
+ self.shape_groups)
458
+
459
+ def load_svg(self, path_svg):
460
+ canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
461
+ return canvas_width, canvas_height, shapes, shape_groups
462
+
463
+
464
+ def softmax_t(x, tau=0.2):
465
+ e_x = np.exp(x / tau)
466
+ return e_x / e_x.sum()
467
+
468
+
469
+ def get_circle_coordinates(center, radius, k):
470
+ coordinates = []
471
+ cx, cy = center
472
+ angle = 2 * math.pi / k
473
+
474
+ for i in range(k):
475
+ theta = i * angle # cur angle
476
+ x = cx + radius * math.cos(theta) # x
477
+ y = cy + radius * math.sin(theta) # y
478
+ coordinates.append((x, y))
479
+
480
+ return coordinates
481
+
482
+
483
+ class LinearDecayLRLambda:
484
+
485
+ def __init__(self, init_lr, keep_ratio, decay_every, decay_ratio):
486
+ self.init_lr = init_lr
487
+ self.keep_ratio = keep_ratio
488
+ self.decay_every = decay_every
489
+ self.decay_ratio = decay_ratio
490
+
491
+ def __call__(self, n):
492
+ if n < self.keep_ratio * self.decay_every:
493
+ return self.init_lr
494
+
495
+ decay_time = n // self.decay_every
496
+ decay_step = n % self.decay_every
497
+ lr_s = self.decay_ratio ** decay_time
498
+ lr_e = self.decay_ratio ** (decay_time + 1)
499
+ r = decay_step / self.decay_every
500
+ lr = lr_s * (1 - r) + lr_e * r
501
+ return lr
502
+
503
+
504
+ class CompPainterOptimizer:
505
+
506
+ def __init__(self,
507
+ renderer: CompPainter,
508
+ style: str,
509
+ num_iter: int,
510
+ lr_config: DictConfig,
511
+ optim_bg: bool = False):
512
+ self.renderer = renderer
513
+ self.style = style
514
+ self.num_iter = num_iter
515
+ self.lr_config = lr_config
516
+ schedule_cfg = self.lr_config.schedule
517
+ self.optim_bg = optim_bg
518
+
519
+ if style == 'iconography':
520
+ self.optim_point, self.optim_color, self.optim_width = True, True, False
521
+ self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio,
522
+ self.num_iter, schedule_cfg.decay_ratio)
523
+ if style == 'pixelart':
524
+ self.optim_point, self.optim_color, self.optim_width = False, True, False
525
+ self.point_lr_lambda = None
526
+ if style == 'sketch':
527
+ self.optim_point, self.optim_color, self.optim_width = True, False, False
528
+ self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio,
529
+ self.num_iter, schedule_cfg.decay_ratio)
530
+ if style == 'ink':
531
+ self.optim_point, self.optim_color, self.optim_width = True, False, True
532
+ self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio,
533
+ self.num_iter, schedule_cfg.decay_ratio)
534
+ if style == 'painting':
535
+ self.optim_point, self.optim_color, self.optim_width = True, True, True
536
+ self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio,
537
+ self.num_iter, schedule_cfg.decay_ratio)
538
+
539
+ self.point_optimizer = None
540
+ self.color_optimizer = None
541
+ self.width_optimizer = None
542
+ self.bg_optimizer = None
543
+
544
+ self.point_scheduler = None
545
+
546
+ def init_optimizers(self, pid_delta=0):
547
+ optim_cfg = self.lr_config.optim
548
+ optim_name = optim_cfg.name
549
+
550
+ params = {}
551
+ if self.optim_point:
552
+ self.renderer.set_points_parameters(pid_delta)
553
+ params['point'] = self.renderer.get_point_params()
554
+
555
+ if len(params['point']) > 0:
556
+ self.point_optimizer = get_optimizer(optim_name, params['point'], self.lr_config.point, optim_cfg)
557
+ if self.point_lr_lambda is not None:
558
+ self.point_scheduler = LambdaLR(self.point_optimizer, lr_lambda=self.point_lr_lambda, last_epoch=-1)
559
+
560
+ if self.optim_color:
561
+ self.renderer.set_color_parameters()
562
+ params['color'] = self.renderer.get_color_params()
563
+ if len(params['color']) > 0:
564
+ self.color_optimizer = get_optimizer(optim_name, params['color'], self.lr_config.color, optim_cfg)
565
+
566
+ if self.optim_width:
567
+ self.renderer.set_width_parameters()
568
+ params['width'] = self.renderer.get_width_params()
569
+ if len(params['width']) > 0:
570
+ self.width_optimizer = get_optimizer(optim_name, params['width'], self.lr_config.width, optim_cfg)
571
+
572
+ if self.optim_bg:
573
+ self.renderer.para_bg.requires_grad = True
574
+ self.bg_optimizer = get_optimizer(optim_name, self.renderer.para_bg, self.lr_config.bg, optim_cfg)
575
+
576
+ def update_lr(self):
577
+ if self.point_scheduler is not None:
578
+ self.point_scheduler.step()
579
+
580
+ def zero_grad_(self):
581
+ if self.point_optimizer is not None:
582
+ self.point_optimizer.zero_grad()
583
+ if self.color_optimizer is not None:
584
+ self.color_optimizer.zero_grad()
585
+ if self.width_optimizer is not None:
586
+ self.width_optimizer.zero_grad()
587
+ if self.bg_optimizer is not None:
588
+ self.bg_optimizer.zero_grad()
589
+
590
+ def step_(self):
591
+ if self.point_optimizer is not None:
592
+ self.point_optimizer.step()
593
+ if self.color_optimizer is not None:
594
+ self.color_optimizer.step()
595
+ if self.width_optimizer is not None:
596
+ self.width_optimizer.step()
597
+ if self.bg_optimizer is not None:
598
+ self.bg_optimizer.step()
599
+
600
+ def get_lr(self) -> Dict:
601
+ lr = {}
602
+ if self.point_optimizer is not None:
603
+ lr['pnt'] = self.point_optimizer.param_groups[0]['lr']
604
+ if self.color_optimizer is not None:
605
+ lr['clr'] = self.color_optimizer.param_groups[0]['lr']
606
+ if self.width_optimizer is not None:
607
+ lr['wd'] = self.width_optimizer.param_groups[0]['lr']
608
+ if self.bg_optimizer is not None:
609
+ lr['bg'] = self.bg_optimizer.param_groups[0]['lr']
610
+ return lr
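When segment_init == 'circle', get_path places num_segments * 3 control points evenly on a circle around the chosen center via get_circle_coordinates. A standalone numeric check of that helper (re-stated here so the snippet runs on its own):

import math

def get_circle_coordinates(center, radius, k):  # same form as the helper above
    cx, cy = center
    angle = 2 * math.pi / k
    return [(cx + radius * math.cos(i * angle),
             cy + radius * math.sin(i * angle)) for i in range(k)]

# 4 cubic segments -> 4 * 3 = 12 control points, all at distance `radius` from the center
pts = get_circle_coordinates(center=(300.0, 300.0), radius=20.0, k=4 * 3)
assert len(pts) == 12
assert all(abs(math.hypot(x - 300.0, y - 300.0) - 20.0) < 1e-9 for x, y in pts)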
svgdreamer/painter/diffusion_pipeline.py ADDED
@@ -0,0 +1,402 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) XiMing Xing. All rights reserved.
3
+ # Author: XiMing Xing
4
+ # Description:
5
+ import PIL
6
+ from PIL import Image
7
+ from typing import Any, List, Optional, Union, Dict
8
+ from omegaconf import DictConfig
9
+
10
+ import numpy as np
11
+ import torch
12
+ from diffusers import StableDiffusionPipeline
13
+ from diffusers import DDIMScheduler
14
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
15
+ rescale_noise_cfg, StableDiffusionPipelineOutput)
16
+
17
+ from svgdreamer.diffusers_warp import init_StableDiffusion_pipeline
18
+ from svgdreamer.token2attn.attn_control import AttentionStore
19
+ from svgdreamer.token2attn.ptp_utils import text_under_image, view_images
20
+
21
+
22
+ class DiffusionPipeline(torch.nn.Module):
23
+
24
+ def __init__(self, model_cfg: DictConfig, diffuser_cfg: DictConfig, device: torch.device):
25
+ super().__init__()
26
+ self.device = device
27
+
28
+ pipe_kwargs = {
29
+ "device": self.device,
30
+ "torch_dtype": torch.float32,
31
+ "local_files_only": not diffuser_cfg.download,
32
+ "force_download": diffuser_cfg.force_download,
33
+ "resume_download": diffuser_cfg.resume_download,
34
+ "ldm_speed_up": model_cfg.ldm_speed_up,
35
+ "enable_xformers": model_cfg.enable_xformers,
36
+ "gradient_checkpoint": model_cfg.gradient_checkpoint,
37
+ "cpu_offload": model_cfg.cpu_offload,
38
+ "vae_slicing": False
39
+ }
40
+
41
+ # load pretrained model
42
+ self.sd_pipeline = init_StableDiffusion_pipeline(
43
+ model_cfg.model_id,
44
+ custom_pipeline=StableDiffusionPipeline,
45
+ custom_scheduler=DDIMScheduler,
46
+ **pipe_kwargs
47
+ )
48
+ # disable grads
49
+ self.sd_pipeline.vae.requires_grad_(False)
50
+ self.sd_pipeline.text_encoder.requires_grad_(False)
51
+ self.sd_pipeline.unet.requires_grad_(False)
52
+ # set components
53
+ self.vae = self.sd_pipeline.vae
54
+ self.unet = self.sd_pipeline.unet
55
+ self.scheduler = self.sd_pipeline.scheduler
56
+ self.tokenizer = self.sd_pipeline.tokenizer
57
+ self.text_encoder = self.sd_pipeline.text_encoder
58
+
59
+ @torch.no_grad()
60
+ def encode_prompt(self,
61
+ prompt,
62
+ device,
63
+ do_classifier_free_guidance,
64
+ negative_prompt=None):
65
+ # text conditional embed
66
+ text_inputs = self.tokenizer(
67
+ prompt,
68
+ padding="max_length",
69
+ max_length=self.tokenizer.model_max_length,
70
+ truncation=True,
71
+ return_tensors="pt",
72
+ )
73
+ prompt_embeds = self.text_encoder(text_inputs.input_ids.to(device))[0]
74
+
75
+ if do_classifier_free_guidance:
76
+ if negative_prompt is None:
77
+ uncond_tokens = [""]
78
+ elif isinstance(negative_prompt, str):
79
+ uncond_tokens = [negative_prompt]
80
+ else:
81
+ uncond_tokens = negative_prompt
82
+
83
+ # unconditional embed
84
+ uncond_input = self.tokenizer(
85
+ uncond_tokens,
86
+ padding="max_length",
87
+ max_length=prompt_embeds.shape[1],
88
+ truncation=True,
89
+ return_tensors="pt",
90
+ )
91
+ negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device))[0]
92
+
93
+ concat_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
94
+ return concat_prompt_embeds, negative_prompt_embeds, prompt_embeds
95
+
96
+ return prompt_embeds, None, None
97
+
98
+ def register_attention_control(self, controller):
99
+ attn_procs = {}
100
+ cross_att_count = 0
101
+ for name in self.unet.attn_processors.keys():
102
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
103
+ if name.startswith("mid_block"):
104
+ hidden_size = self.unet.config.block_out_channels[-1]
105
+ place_in_unet = "mid"
106
+ elif name.startswith("up_blocks"):
107
+ block_id = int(name[len("up_blocks.")])
108
+ hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
109
+ place_in_unet = "up"
110
+ elif name.startswith("down_blocks"):
111
+ block_id = int(name[len("down_blocks.")])
112
+ hidden_size = self.unet.config.block_out_channels[block_id]
113
+ place_in_unet = "down"
114
+ else:
115
+ continue
116
+ cross_att_count += 1
117
+ attn_procs[name] = P2PCrossAttnProcessor(
118
+ controller=controller, place_in_unet=place_in_unet
119
+ )
120
+
121
+ self.unet.set_attn_processor(attn_procs)
122
+ controller.num_att_layers = cross_att_count
123
+
124
+ @staticmethod
125
+ def aggregate_attention(prompts,
126
+ attention_store: AttentionStore,
127
+ res: int,
128
+ from_where: List[str],
129
+ is_cross: bool,
130
+ select: int):
131
+ if isinstance(prompts, str):
132
+ prompts = [prompts]
133
+ assert isinstance(prompts, list)
134
+
135
+ out = []
136
+ attention_maps = attention_store.get_average_attention()
137
+ num_pixels = res ** 2
138
+ for location in from_where:
139
+ for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
140
+ if item.shape[1] == num_pixels:
141
+ cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
142
+ out.append(cross_maps)
143
+ out = torch.cat(out, dim=0)
144
+ out = out.sum(0) / out.shape[0]
145
+ return out.cpu()
146
+
147
+ def get_cross_attention(self,
148
+ prompts,
149
+ attention_store: AttentionStore,
150
+ res: int,
151
+ from_where: List[str],
152
+ select: int = 0,
153
+ save_path=None):
154
+ tokens = self.tokenizer.encode(prompts[select])
155
+ decoder = self.tokenizer.decode
156
+ # aggregated cross-attention, shape: [res, res, seq_len]
157
+ attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, True, select)
158
+
159
+ images_text = []
160
+ images = []
161
+ for i in range(len(tokens)):
162
+ image = attention_maps[:, :, i]
163
+ image = 255 * image / image.max()
164
+ image = image.unsqueeze(-1).expand(*image.shape, 3)
165
+ image = image.numpy().astype(np.uint8)
166
+ image = np.array(Image.fromarray(image).resize((256, 256)))
167
+ images.append(np.copy(image))
168
+ image = text_under_image(image, decoder(int(tokens[i])))
169
+ images_text.append(image)
170
+ image_array = np.stack(images_text, axis=0)
171
+ view_images(image_array, save_image=True, fp=save_path)
172
+
173
+ return attention_maps, tokens
174
+
175
+ def get_self_attention_comp(self,
176
+ prompts,
177
+ attention_store: AttentionStore,
178
+ res: int,
179
+ from_where: List[str],
180
+ img_size: int = 224,
181
+ max_com=10,
182
+ select: int = 0,
183
+ save_path=None):
184
+ attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, False, select)
185
+ attention_maps = attention_maps.numpy().reshape((res ** 2, res ** 2))
186
+ # shape: [res ** 2, res ** 2]
187
+ u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
188
+ print(f"self-attention maps: {attention_maps.shape}, "
189
+ f"u: {u.shape}, "
190
+ f"s: {s.shape}, "
191
+ f"vh: {vh.shape}")
192
+
193
+ images = []
194
+ vh_returns = []
195
+ for i in range(max_com):
196
+ image = vh[i].reshape(res, res)
197
+ image = (image - image.min()) / (image.max() - image.min())
198
+ image = 255 * image
199
+
200
+ ret_ = Image.fromarray(image).resize((img_size, img_size), resample=PIL.Image.Resampling.BILINEAR)
201
+ vh_returns.append(np.array(ret_))
202
+
203
+ image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
204
+ image = Image.fromarray(image).resize((256, 256))
205
+ image = np.array(image)
206
+ images.append(image)
207
+ image_array = np.stack(images, axis=0)
208
+ view_images(image_array, num_rows=max_com // 10, offset_ratio=0,
209
+ save_image=True, fp=save_path / "self-attn-vh.png")
210
+
211
+ return attention_maps, (u, s, vh), np.stack(vh_returns, axis=0)
212
+
213
+ def sampling(self,
214
+ vae,
215
+ unet,
216
+ scheduler,
217
+ prompt: Union[str, List[str]] = None,
218
+ height: Optional[int] = None,
219
+ width: Optional[int] = None,
220
+ controller: AttentionStore = None, # feed attention_store as control of ptp
221
+ num_inference_steps: int = 50,
222
+ guidance_scale: float = 7.5,
223
+ negative_prompt: Optional[Union[str, List[str]]] = None,
224
+ num_images_per_prompt: Optional[int] = 1,
225
+ eta: float = 0.0,
226
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
227
+ latents: Optional[torch.FloatTensor] = None,
228
+ output_type: Optional[str] = "pil",
229
+ return_dict: bool = True,
230
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
231
+ guidance_rescale: float = 0.0):
232
+
233
+ # add attention controller
234
+ self.register_attention_control(controller)
235
+
236
+ # 0. Default height and width to unet
237
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
238
+ height = height or unet.config.sample_size * vae_scale_factor
239
+ width = width or unet.config.sample_size * vae_scale_factor
240
+
241
+ # 2. Define call parameters
242
+ if prompt is not None and isinstance(prompt, list):
243
+ batch_size = len(prompt)
244
+ else:
245
+ batch_size = 1
246
+
247
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
248
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
249
+ # corresponds to doing no classifier free guidance.
250
+ do_classifier_free_guidance = guidance_scale > 1.0
251
+
252
+ # 3. Encode input prompt
253
+ prompt_embeds, _, _ = self.encode_prompt(
254
+ prompt,
255
+ self.device,
256
+ do_classifier_free_guidance,
257
+ negative_prompt,
258
+ )
259
+
260
+ # 4. Prepare timesteps
261
+ scheduler.set_timesteps(num_inference_steps, device=self.device)
262
+ timesteps = scheduler.timesteps
263
+
264
+ # 5. Prepare latent variables
265
+ num_channels_latents = unet.config.in_channels
266
+ latents = self.sd_pipeline.prepare_latents(
267
+ batch_size * num_images_per_prompt,
268
+ num_channels_latents,
269
+ height,
270
+ width,
271
+ prompt_embeds.dtype,
272
+ self.device,
273
+ generator,
274
+ latents,
275
+ )
276
+
277
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
278
+ extra_step_kwargs = self.sd_pipeline.prepare_extra_step_kwargs(generator, eta)
279
+
280
+ # 7. Denoising loop
281
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
282
+ with self.sd_pipeline.progress_bar(total=num_inference_steps) as progress_bar:
283
+ for i, t in enumerate(timesteps):
284
+ # expand the latents if we are doing classifier free guidance
285
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
286
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
287
+
288
+ # predict the noise residual
289
+ noise_pred = unet(
290
+ latent_model_input,
291
+ t,
292
+ encoder_hidden_states=prompt_embeds,
293
+ cross_attention_kwargs=cross_attention_kwargs,
294
+ return_dict=False,
295
+ )[0]
296
+
297
+ # perform guidance
298
+ if do_classifier_free_guidance:
299
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
300
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
301
+
302
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
303
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
304
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
305
+
306
+ # compute the previous noisy sample x_t -> x_t-1
307
+ latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
308
+
309
+ # controller callback
310
+ latents = controller.step_callback(latents)
311
+
312
+ # update progress_bar
313
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
314
+ progress_bar.update()
315
+
316
+ if not output_type == "latent":
317
+ image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
318
+ image, has_nsfw_concept = self.sd_pipeline.run_safety_checker(image, self.device, prompt_embeds.dtype)
319
+ else:
320
+ image = latents
321
+ has_nsfw_concept = None
322
+
323
+ if has_nsfw_concept is None:
324
+ do_denormalize = [True] * image.shape[0]
325
+ else:
326
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
327
+
328
+ image = self.sd_pipeline.image_processor.postprocess(image, output_type=output_type,
329
+ do_denormalize=do_denormalize)
330
+ # Offload last model to CPU
331
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
332
+ self.final_offload_hook.offload()
333
+
334
+ if not return_dict:
335
+ return (image, has_nsfw_concept)
336
+
337
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
338
+
339
+ def sample(self,
340
+ prompt,
341
+ height: Optional[int] = None,
342
+ width: Optional[int] = None,
343
+ controller: AttentionStore = None, # feed attention_store as control of ptp
344
+ num_inference_steps: int = 50,
345
+ guidance_scale: float = 7.5,
346
+ negative_prompt: Optional[Union[str, List[str]]] = None,
347
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
348
+ output_type: Optional[str] = "pil"):
349
+ return self.sampling(self.vae, self.unet, self.scheduler,
350
+ prompt=prompt,
351
+ height=height, width=width,
352
+ controller=controller,
353
+ num_inference_steps=num_inference_steps,
354
+ guidance_scale=guidance_scale,
355
+ negative_prompt=negative_prompt,
356
+ generator=generator,
357
+ output_type=output_type)
358
+
359
+ def encode2latent(self, images):
360
+ images = (2 * images - 1).clamp(-1.0, 1.0) # images: [B, 3, H, W]
361
+ # encode images
362
+ latents = self.vae.encode(images).latent_dist.sample()
363
+ latents = self.vae.config.scaling_factor * latents
364
+ return latents
365
+
366
+
367
+ class P2PCrossAttnProcessor:
368
+
369
+ def __init__(self, controller, place_in_unet):
370
+ super().__init__()
371
+ self.controller = controller
372
+ self.place_in_unet = place_in_unet
373
+
374
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
375
+ batch_size, sequence_length, _ = hidden_states.shape
376
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size)
377
+
378
+ query = attn.to_q(hidden_states)
379
+
380
+ is_cross = encoder_hidden_states is not None
381
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
382
+ key = attn.to_k(encoder_hidden_states)
383
+ value = attn.to_v(encoder_hidden_states)
384
+
385
+ query = attn.head_to_batch_dim(query)
386
+ key = attn.head_to_batch_dim(key)
387
+ value = attn.head_to_batch_dim(value)
388
+
389
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
390
+
391
+ # one line change
392
+ self.controller(attention_probs, is_cross, self.place_in_unet)
393
+
394
+ hidden_states = torch.bmm(attention_probs, value)
395
+ hidden_states = attn.batch_to_head_dim(hidden_states)
396
+
397
+ # linear proj
398
+ hidden_states = attn.to_out[0](hidden_states)
399
+ # dropout
400
+ hidden_states = attn.to_out[1](hidden_states)
401
+
402
+ return hidden_states
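In the denoising loop of sampling, classifier-free guidance combines the unconditional and text-conditional noise predictions as eps = eps_uncond + w * (eps_text - eps_uncond), so w = 1 reduces to the plain conditional prediction, which is why CFG is only applied when guidance_scale > 1.0. A standalone numeric restatement of that step (tensors are random placeholders):

import torch

noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)
guidance_scale = 7.5

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# with w = 1 the combination collapses back to the conditional prediction
assert torch.allclose(
    noise_pred_uncond + 1.0 * (noise_pred_text - noise_pred_uncond),
    noise_pred_text,
    atol=1e-6,
)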