diff --git a/ImageReward/ImageReward.py b/ImageReward/ImageReward.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dcd9a1216ba000c778069380ee77a9f4dabd28b
--- /dev/null
+++ b/ImageReward/ImageReward.py
@@ -0,0 +1,177 @@
+'''
+@File : ImageReward.py
+@Time : 2023/01/28 19:53:00
+@Author : Jiazheng Xu
+@Contact : xjz22@mails.tsinghua.edu.cn
+@Description: ImageReward reward model.
+* Based on CLIP code base and improved-aesthetic-predictor code base
+* https://github.com/openai/CLIP
+* https://github.com/christophschuhmann/improved-aesthetic-predictor
+'''
+
+import os
+import torch
+import torch.nn as nn
+from PIL import Image
+from .models.BLIP.blip_pretrain import BLIP_Pretrain
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+
+try:
+    from torchvision.transforms import InterpolationMode
+
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+
+def _convert_image_to_rgb(image):
+    return image.convert("RGB")
+
+
+def _transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
+class MLP(nn.Module):
+    def __init__(self, input_size):
+        super().__init__()
+        self.input_size = input_size
+
+        self.layers = nn.Sequential(
+            nn.Linear(self.input_size, 1024),
+            # nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 128),
+            # nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            # nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(64, 16),
+            # nn.ReLU(),
+            nn.Linear(16, 1)
+        )
+
+        # initialize the MLP parameters
+        for name, param in self.layers.named_parameters():
+            if 'weight' in name:
+                nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
+            if 'bias' in name:
+                nn.init.constant_(param, val=0)
+
+    def forward(self, input):
+        return self.layers(input)
+
+
+class ImageReward(nn.Module):
+    def __init__(self, med_config, device='cpu'):
+        super().__init__()
+        self.device = device
+
+        self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
+        self.preprocess = _transform(224)
+        self.mlp = MLP(768)
+
+        self.mean = 0.16717362830052426
+        self.std = 1.0333394966054072
+
+    def score_gard(self, prompt_ids, prompt_attention_mask, image):
+
+        image_embeds = self.blip.visual_encoder(image)
+        # text encode, cross-attending to the image
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+        text_output = self.blip.text_encoder(prompt_ids,
+                                             attention_mask=prompt_attention_mask,
+                                             encoder_hidden_states=image_embeds,
+                                             encoder_attention_mask=image_atts,
+                                             return_dict=True,
+                                             )
+
+        txt_features = text_output.last_hidden_state[:, 0, :]  # (feature_dim)
+        rewards = self.mlp(txt_features)
+        rewards = (rewards - self.mean) / self.std
+
+        return rewards
+
+    def score(self, prompt, image):
+
+        if (type(image).__name__ == 'list'):
+            _, rewards = self.inference_rank(prompt, image)
+            return rewards
+
+        # text encode
+        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
+                                         return_tensors="pt").to(self.device)
+
+        # image encode
+        if isinstance(image, Image.Image):
+            pil_image = image
+        elif isinstance(image, str):
+            if os.path.isfile(image):
+                pil_image = Image.open(image)
+        else:
+            raise TypeError(
+                r'This image parameter type has not been supported yet. Please pass a PIL.Image or a file path str.')
+
+        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+        image_embeds = self.blip.visual_encoder(image)
+
+        # text encode, cross-attending to the image
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+        text_output = self.blip.text_encoder(text_input.input_ids,
+                                             attention_mask=text_input.attention_mask,
+                                             encoder_hidden_states=image_embeds,
+                                             encoder_attention_mask=image_atts,
+                                             return_dict=True,
+                                             )
+
+        txt_features = text_output.last_hidden_state[:, 0, :].float()  # (feature_dim)
+        rewards = self.mlp(txt_features)
+        rewards = (rewards - self.mean) / self.std
+
+        return rewards.detach().cpu().numpy().item()
+
+    def inference_rank(self, prompt, generations_list):
+
+        text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35,
+                                         return_tensors="pt").to(self.device)
+
+        txt_set = []
+        for generation in generations_list:
+            # image encode
+            if isinstance(generation, Image.Image):
+                pil_image = generation
+            elif isinstance(generation, str):
+                if os.path.isfile(generation):
+                    pil_image = Image.open(generation)
+            else:
+                raise TypeError(
+                    r'This image parameter type has not been supported yet. Please pass a PIL.Image or a file path str.')
+
+            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+            image_embeds = self.blip.visual_encoder(image)
+
+            # text encode, cross-attending to the image
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
+            text_output = self.blip.text_encoder(text_input.input_ids,
+                                                 attention_mask=text_input.attention_mask,
+                                                 encoder_hidden_states=image_embeds,
+                                                 encoder_attention_mask=image_atts,
+                                                 return_dict=True)
+            txt_set.append(text_output.last_hidden_state[:, 0, :])
+
+        txt_features = torch.cat(txt_set, 0).float()  # [image_num, feature_dim]
+        rewards = self.mlp(txt_features)  # [image_num, 1]
+        rewards = (rewards - self.mean) / self.std
+        rewards = torch.squeeze(rewards)
+        _, rank = torch.sort(rewards, dim=0, descending=True)
+        _, indices = torch.sort(rank, dim=0)
+        indices = indices + 1
+
+        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
diff --git a/ImageReward/ReFL.py b/ImageReward/ReFL.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8718b3b5e0cf605219b3931f20b043bcd67c6c9
--- /dev/null
+++ b/ImageReward/ReFL.py
@@ -0,0 +1,830 @@
+'''
+@File : ReFL.py
+@Time : 2023/05/01 19:36:00
+@Author : Jiazheng Xu
+@Contact : xjz22@mails.tsinghua.edu.cn
+@Description: ReFL Algorithm.
+* Based on diffusers code base +* https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py +''' + +import argparse +import logging +import math +import os +import random +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from PIL import Image +import ImageReward as RM + +from torchvision.transforms import Compose, Resize, CenterCrop, Normalize + +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate +from diffusers.utils.import_utils import is_xformers_available + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.16.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "refl": ("image", "text"), +} + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--grad_scale", type=float, default=1e-3, help="Scale divided for grad loss value." + ) + parser.add_argument( + "--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1." + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="checkpoint/refl", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=100, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=4, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=100, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. 
Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-refl", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +class Trainer(object): + + def __init__(self, pretrained_model_name_or_path, train_data_dir, args): + + self.pretrained_model_name_or_path = pretrained_model_name_or_path + self.train_data_dir = train_data_dir + + # Sanity checks + if args.dataset_name is None and self.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + + self.accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + logging_dir=logging_dir, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(self.accelerator.state, main_process_only=False) + if self.accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if self.accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + self.repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, tokenizer and models. 
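+        # ReFL fine-tunes only the UNet: the VAE, CLIP text encoder, and ImageReward reward model
+        # loaded below are frozen via requires_grad_(False) and used purely for inference.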
+ self.noise_scheduler = DDPMScheduler.from_pretrained(self.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained( + self.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + self.text_encoder = CLIPTextModel.from_pretrained( + self.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + self.vae = AutoencoderKL.from_pretrained(self.pretrained_model_name_or_path, subfolder="vae", + revision=args.revision) + self.unet = UNet2DConditionModel.from_pretrained( + self.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + ) + self.reward_model = RM.load("ImageReward-v1.0", device=self.accelerator.device) + + # Freeze vae and text_encoder + self.vae.requires_grad_(False) + self.text_encoder.requires_grad_(False) + self.reward_model.requires_grad_(False) + + # Create EMA for the unet. + if args.use_ema: + self.ema_unet = UNet2DConditionModel.from_pretrained( + self.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + self.ema_unet = EMAModel(self.ema_unet.parameters(), model_cls=UNet2DConditionModel, + model_config=self.ema_unet.config) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + self.unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `self.accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + self.ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + self.ema_unet.load_state_dict(load_model.state_dict()) + self.ema_unet.to(self.accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + self.accelerator.register_save_state_pre_hook(save_model_hook) + self.accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + self.unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * self.accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + self.optimizer = optimizer_cls( + self.unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + data_files["train"] = self.train_data_dir + dataset = load_dataset( + "json", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # Get the column names for input/target. 
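+        # Column resolution below: an explicit --image_column/--caption_column wins, otherwise the
+        # DATASET_NAME_MAPPING entry for the dataset is used, falling back to the first/second column.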
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, + return_tensors="pt" + ) + return inputs.input_ids + + def preprocess_train(examples): + examples["input_ids"] = tokenize_captions(examples) + examples["rm_input_ids"] = self.reward_model.blip.tokenizer(examples[caption_column], padding='max_length', + truncation=True, max_length=35, + return_tensors="pt").input_ids + examples["rm_attention_mask"] = self.reward_model.blip.tokenizer(examples[caption_column], + padding='max_length', truncation=True, + max_length=35, + return_tensors="pt").attention_mask + return examples + + with self.accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + self.train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + input_ids = torch.stack([example["input_ids"] for example in examples]) + rm_input_ids = torch.stack([example["rm_input_ids"] for example in examples]) + rm_attention_mask = torch.stack([example["rm_attention_mask"] for example in examples]) + input_ids = input_ids.view(-1, input_ids.shape[-1]) + rm_input_ids = rm_input_ids.view(-1, rm_input_ids.shape[-1]) + rm_attention_mask = rm_attention_mask.view(-1, rm_attention_mask.shape[-1]) + return {"input_ids": input_ids, "rm_input_ids": rm_input_ids, "rm_attention_mask": rm_attention_mask} + + # DataLoaders creation: + self.train_dataloader = torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. 
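+        # When args.max_train_steps is None it is derived as
+        #   num_train_epochs * ceil(len(train_dataloader) / gradient_accumulation_steps)
+        # and recomputed after accelerator.prepare(), since sharding the dataloader across
+        # processes can change its length.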
+ overrode_max_train_steps = False + self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch + overrode_max_train_steps = True + + self.lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=self.optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `self.accelerator`. + self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare( + self.unet, self.optimizer, self.train_dataloader, self.lr_scheduler + ) + + if args.use_ema: + self.ema_unet.to(self.accelerator.device) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + self.weight_dtype = torch.float32 + if self.accelerator.mixed_precision == "fp16": + self.weight_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + self.weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu and cast to self.weight_dtype + self.text_encoder.to(self.accelerator.device, dtype=self.weight_dtype) + self.vae.to(self.accelerator.device, dtype=self.weight_dtype) + self.reward_model.to(self.accelerator.device, dtype=self.weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / self.num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if self.accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + self.accelerator.init_trackers(args.tracker_project_name, tracker_config) + + def train(self, args): + + # Train! + total_batch_size = args.train_batch_size * self.accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(self.train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + self.accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + self.accelerator.print(f"Resuming from checkpoint {path}") + self.accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // self.num_update_steps_per_epoch + resume_step = resume_global_step % (self.num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), + disable=not self.accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + self.unet.train() + train_loss = 0.0 + for step, batch in enumerate(self.train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with self.accelerator.accumulate(self.unet): + encoder_hidden_states = self.text_encoder(batch["input_ids"])[0] + latents = torch.randn((args.train_batch_size, 4, 64, 64), device=self.accelerator.device) + + self.noise_scheduler.set_timesteps(40, device=self.accelerator.device) + timesteps = self.noise_scheduler.timesteps + + mid_timestep = random.randint(30, 39) + + for i, t in enumerate(timesteps[:mid_timestep]): + with torch.no_grad(): + latent_model_input = latents + latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t) + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=encoder_hidden_states, + ).sample + latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample + + latent_model_input = latents + latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, + timesteps[mid_timestep]) + noise_pred = self.unet( + latent_model_input, + timesteps[mid_timestep], + encoder_hidden_states=encoder_hidden_states, + ).sample + pred_original_sample = self.noise_scheduler.step(noise_pred, timesteps[mid_timestep], + latents).pred_original_sample.to(self.weight_dtype) + + pred_original_sample = 1 / self.vae.config.scaling_factor * pred_original_sample + image = self.vae.decode(pred_original_sample.to(self.weight_dtype)).sample + image = (image / 2 + 0.5).clamp(0, 1) + + # image encode + def _transform(): + return Compose([ + Resize(224, interpolation=BICUBIC), + CenterCrop(224), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + rm_preprocess = _transform() + image = rm_preprocess(image).to(self.accelerator.device) + + rewards 
= self.reward_model.score_gard(batch["rm_input_ids"], batch["rm_attention_mask"], image) + loss = F.relu(-rewards + 2) + loss = loss.mean() * args.grad_scale + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = self.accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_(self.unet.parameters(), args.max_grad_norm) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + # Checks if the self.accelerator has performed an optimization step behind the scenes + if self.accelerator.sync_gradients: + if args.use_ema: + self.ema_unet.step(self.unet.parameters()) + progress_bar.update(1) + global_step += 1 + self.accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if self.accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + self.accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if self.accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + self.ema_unet.store(self.unet.parameters()) + self.ema_unet.copy_to(self.unet.parameters()) + if args.use_ema: + # Switch back to the original UNet parameters. + self.ema_unet.restore(self.unet.parameters()) + + # Create the pipeline using the trained modules and save it. + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process: + self.unet = self.accelerator.unwrap_model(self.unet) + if args.use_ema: + self.ema_unet.copy_to(self.unet.parameters()) + + pipeline = StableDiffusionPipeline.from_pretrained( + self.pretrained_model_name_or_path, + text_encoder=self.text_encoder, + vae=self.vae, + unet=self.unet, + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + upload_folder( + repo_id=self.repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + self.accelerator.end_training() diff --git a/ImageReward/__init__.py b/ImageReward/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec7d51d8e7417474542f05883398390c37a6ba6 --- /dev/null +++ b/ImageReward/__init__.py @@ -0,0 +1,3 @@ +from .utils import * +from .models import * +from .ReFL import * \ No newline at end of file diff --git a/ImageReward/models/AestheticScore.py b/ImageReward/models/AestheticScore.py new file mode 100644 index 0000000000000000000000000000000000000000..aeefd0f515e803085b16dda2497b34babe5c684e --- /dev/null +++ b/ImageReward/models/AestheticScore.py @@ -0,0 +1,95 @@ +''' +@File : AestheticScore.py +@Time : 2023/02/12 14:54:00 +@Auther : Jiazheng Xu +@Contact : xjz22@mails.tsinghua.edu.cn +@Description: AestheticScore. 
+* Based on improved-aesthetic-predictor code base
+* https://github.com/christophschuhmann/improved-aesthetic-predictor
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+import clip
+
+
+# if you changed the MLP architecture during training, change it here as well:
+class MLP(nn.Module):
+    def __init__(self, input_size):
+        super().__init__()
+        self.input_size = input_size
+        self.layers = nn.Sequential(
+            nn.Linear(self.input_size, 1024),
+            # nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 128),
+            # nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            # nn.ReLU(),
+            nn.Dropout(0.1),
+
+            nn.Linear(64, 16),
+            # nn.ReLU(),
+
+            nn.Linear(16, 1)
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class AestheticScore(nn.Module):
+    def __init__(self, download_root, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False,
+                                                     download_root=download_root)
+        self.mlp = MLP(768)
+
+        if device == "cpu":
+            self.clip_model.float()
+        else:
+            clip.model.convert_weights(
+                self.clip_model)  # strictly unnecessary: CLIP weights are already float16 by default
+
+        # clip.logit_scale is not needed for scoring, so keep it frozen.
+        self.clip_model.logit_scale.requires_grad_(False)
+
+    def score(self, prompt, image_path):
+
+        if (type(image_path).__name__ == 'list'):
+            _, rewards = self.inference_rank(prompt, image_path)
+            return rewards
+
+        # image encode
+        pil_image = Image.open(image_path)
+        image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+        image_features = F.normalize(self.clip_model.encode_image(image)).float()
+
+        # score
+        rewards = self.mlp(image_features)
+
+        return rewards.detach().cpu().numpy().item()
+
+    def inference_rank(self, prompt, generations_list):
+
+        img_set = []
+        for generations in generations_list:
+            # image encode
+            img_path = generations
+            pil_image = Image.open(img_path)
+            image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
+            image_features = F.normalize(self.clip_model.encode_image(image))
+            img_set.append(image_features)
+
+        img_features = torch.cat(img_set, 0).float()  # [image_num, feature_dim]
+        rewards = self.mlp(img_features)
+        rewards = torch.squeeze(rewards)
+        _, rank = torch.sort(rewards, dim=0, descending=True)
+        _, indices = torch.sort(rank, dim=0)
+        indices = indices + 1
+
+        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
diff --git a/ImageReward/models/BLIP/__init__.py b/ImageReward/models/BLIP/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a617e7dda333d40ed10207f44ccc3857fb18ad4
--- /dev/null
+++ b/ImageReward/models/BLIP/__init__.py
@@ -0,0 +1 @@
+from .blip_pretrain import *
\ No newline at end of file
diff --git a/ImageReward/models/BLIP/blip.py b/ImageReward/models/BLIP/blip.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dfdb72ab619587b62357904349358b221f631e4
--- /dev/null
+++ b/ImageReward/models/BLIP/blip.py
@@ -0,0 +1,70 @@
+'''
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
+'''
+
+import warnings
+warnings.filterwarnings("ignore")
+
+import torch
+import os
+from urllib.parse import urlparse
+from timm.models.hub import download_cached_file
+from transformers import BertTokenizer
+from .vit import VisionTransformer, interpolate_pos_embed
+
+
+def init_tokenizer():
+    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+    tokenizer.add_special_tokens({'bos_token':'[DEC]'})
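+    # Per the upstream BLIP code, [DEC] serves as the decoder's begin-of-sequence token and
+    # [ENC] (added on the next line) as the special encoder token used for image-text matching.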
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) + tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] + return tokenizer + + +def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): + + assert vit in ['base', 'large'], "vit parameter must be base or large" + if vit=='base': + vision_width = 768 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, + num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0 or drop_path_rate + ) + elif vit=='large': + vision_width = 1024 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, + num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0.1 or drop_path_rate + ) + return visual_encoder, vision_width + + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + +def load_checkpoint(model,url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) + if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): + state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], + model.visual_encoder_m) + for key in model.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape!=model.state_dict()[key].shape: + print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape) + del state_dict[key] + + msg = model.load_state_dict(state_dict,strict=False) + print('load checkpoint from %s'%url_or_filename) + return model,msg + diff --git a/ImageReward/models/BLIP/blip_pretrain.py b/ImageReward/models/BLIP/blip_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..793cb07944810eebe1d28f26aa19482b0abcf0a5 --- /dev/null +++ b/ImageReward/models/BLIP/blip_pretrain.py @@ -0,0 +1,43 @@ +''' + * Adapted from BLIP (https://github.com/salesforce/BLIP) +''' + +import transformers +transformers.logging.set_verbosity_error() + +from torch import nn +import os +from .med import BertConfig, BertModel +from .blip import create_vit, init_tokenizer + +class BLIP_Pretrain(nn.Module): + def __init__(self, + med_config = "med_config.json", + image_size = 224, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + embed_dim = 256, + queue_size = 57600, + momentum = 0.995, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0) + + self.tokenizer = init_tokenizer() + encoder_config = BertConfig.from_json_file(med_config) + encoder_config.encoder_width = vision_width + self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) + + text_width = self.text_encoder.config.hidden_size + + self.vision_proj = 
nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + diff --git a/ImageReward/models/BLIP/med.py b/ImageReward/models/BLIP/med.py new file mode 100644 index 0000000000000000000000000000000000000000..426f4689833d988526c6e26cd627f30975ab7606 --- /dev/null +++ b/ImageReward/models/BLIP/med.py @@ -0,0 +1,947 @@ +''' + * Adapted from BLIP (https://github.com/salesforce/BLIP) + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +''' + +import math +from typing import Tuple + +import torch +from torch import Tensor, device, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, 
config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
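+        # query_layer/key_layer have shape [batch, num_heads, seq_len, head_dim] after transpose_for_scores,
+        # so the scores computed below are [batch, num_heads, query_len, key_len]; in cross-attention the
+        # keys/values come from encoder_hidden_states (the image embeddings in this codebase).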
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            seq_length = hidden_states.size()[1]
+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+        if is_cross_attention and self.save_attention:
+            self.save_attention_map(attention_probs)
+            attention_probs.register_hook(self.save_attn_gradients)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = 
self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode=='multimodal': + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + 
return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + if reduction=='none': + lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/ImageReward/models/BLIP/vit.py b/ImageReward/models/BLIP/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5cf430090956461bc64d5ccbe427a71f50f5f2 --- /dev/null +++ b/ImageReward/models/BLIP/vit.py @@ -0,0 +1,301 @@ +''' + * Adapted from BLIP (https://github.com/salesforce/BLIP) + * Based on timm code base + * https://github.com/rwightman/pytorch-image-models/tree/master/timm +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial 
+ +from timm.models.vision_transformer import _cfg, PatchEmbed +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_, DropPath +from timm.models.helpers import named_apply, adapt_input_conv + +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, register_hook=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def forward(self, x, register_hook=False): + x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, + use_grad_checkpointing=False, ckpt_layer=0): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, register_blk=-1): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:,:x.size(1),:] + x 
= self.pos_drop(x) + + for i,blk in enumerate(self.blocks): + x = blk(x, register_blk==i) + x = self.norm(x) + + return x + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) +# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: +# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) +# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) +# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: +# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) +# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + 
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = visual_encoder.patch_embed.num_patches + num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + + if orig_size!=new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) + + return new_pos_embed + else: + return pos_embed_checkpoint \ No newline at end of file diff --git a/ImageReward/models/BLIPScore.py b/ImageReward/models/BLIPScore.py new file mode 100644 index 0000000000000000000000000000000000000000..a44ed3b3d1008d659559ab1643ad251dc4b80287 --- /dev/null +++ b/ImageReward/models/BLIPScore.py @@ -0,0 +1,97 @@ +''' +@File : BLIPScore.py +@Time : 2023/02/19 20:48:00 +@Auther : Jiazheng Xu +@Contact : xjz22@mails.tsinghua.edu.cn +@Description: BLIPScore. 
+* Based on BLIP code base +* https://github.com/salesforce/BLIP +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from ImageReward.models.BLIP.blip_pretrain import BLIP_Pretrain +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +class BLIPScore(nn.Module): + def __init__(self, med_config, device='cpu'): + super().__init__() + self.device = device + + self.preprocess = _transform(224) + self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config) + + + def score(self, prompt, image_path): + + if (type(image_path).__name__=='list'): + _, rewards = self.inference_rank(prompt, image_path) + return rewards + + # text encode + text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device) + text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text') + txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:])) + + # image encode + pil_image = Image.open(image_path) + image = self.preprocess(pil_image).unsqueeze(0).to(self.device) + image_embeds = self.blip.visual_encoder(image) + image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1) + + # score + rewards = torch.sum(torch.mul(txt_feature, image_features), dim=1, keepdim=True) + + return rewards.detach().cpu().numpy().item() + + + def inference_rank(self, prompt, generations_list): + + text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device) + text_output = self.blip.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text') + txt_feature = F.normalize(self.blip.text_proj(text_output.last_hidden_state[:,0,:])) + + txt_set = [] + img_set = [] + for generations in generations_list: + # image encode + img_path = generations + pil_image = Image.open(img_path) + image = self.preprocess(pil_image).unsqueeze(0).to(self.device) + image_embeds = self.blip.visual_encoder(image) + image_features = F.normalize(self.blip.vision_proj(image_embeds[:,0,:]), dim=-1) + img_set.append(image_features) + txt_set.append(txt_feature) + + txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim] + img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim] + rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True) + rewards = torch.squeeze(rewards) + _, rank = torch.sort(rewards, dim=0, descending=True) + _, indices = torch.sort(rank, dim=0) + indices = indices + 1 + + return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist() \ No newline at end of file diff --git a/ImageReward/models/CLIPScore.py b/ImageReward/models/CLIPScore.py new file mode 100644 index 0000000000000000000000000000000000000000..8aba714ed0da54704a22e9a34c4c639be9c0aec3 --- /dev/null +++ b/ImageReward/models/CLIPScore.py @@ -0,0 +1,78 @@ +''' +@File : CLIPScore.py +@Time : 2023/02/12 13:14:00 
+@Auther : Jiazheng Xu +@Contact : xjz22@mails.tsinghua.edu.cn +@Description: CLIPScore. +* Based on CLIP code base +* https://github.com/openai/CLIP +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +import clip + +class CLIPScore(nn.Module): + def __init__(self, download_root, device='cpu'): + super().__init__() + self.device = device + self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device, jit=False, + download_root=download_root) + + if device == "cpu": + self.clip_model.float() + else: + clip.model.convert_weights(self.clip_model) # Actually this line is unnecessary since clip by default already on float16 + + # have clip.logit_scale require no grad. + self.clip_model.logit_scale.requires_grad_(False) + + + def score(self, prompt, image_path): + + if (type(image_path).__name__=='list'): + _, rewards = self.inference_rank(prompt, image_path) + return rewards + + # text encode + text = clip.tokenize(prompt, truncate=True).to(self.device) + txt_features = F.normalize(self.clip_model.encode_text(text)) + + # image encode + pil_image = Image.open(image_path) + image = self.preprocess(pil_image).unsqueeze(0).to(self.device) + image_features = F.normalize(self.clip_model.encode_image(image)) + + # score + rewards = torch.sum(torch.mul(txt_features, image_features), dim=1, keepdim=True) + + return rewards.detach().cpu().numpy().item() + + + def inference_rank(self, prompt, generations_list): + + text = clip.tokenize(prompt, truncate=True).to(self.device) + txt_feature = F.normalize(self.clip_model.encode_text(text)) + + txt_set = [] + img_set = [] + for generations in generations_list: + # image encode + img_path = generations + pil_image = Image.open(img_path) + image = self.preprocess(pil_image).unsqueeze(0).to(self.device) + image_features = F.normalize(self.clip_model.encode_image(image)) + img_set.append(image_features) + txt_set.append(txt_feature) + + txt_features = torch.cat(txt_set, 0).float() # [image_num, feature_dim] + img_features = torch.cat(img_set, 0).float() # [image_num, feature_dim] + rewards = torch.sum(torch.mul(txt_features, img_features), dim=1, keepdim=True) + rewards = torch.squeeze(rewards) + _, rank = torch.sort(rewards, dim=0, descending=True) + _, indices = torch.sort(rank, dim=0) + indices = indices + 1 + + return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist() \ No newline at end of file diff --git a/ImageReward/models/__init__.py b/ImageReward/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba230b0a38758ee78a4eba7caeedc259a1a4dbb --- /dev/null +++ b/ImageReward/models/__init__.py @@ -0,0 +1,4 @@ +from .AestheticScore import * +from .BLIPScore import * +from .CLIPScore import * +from .BLIP import * \ No newline at end of file diff --git a/ImageReward/utils.py b/ImageReward/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f897717701a682fbcce751f8a793a74fcf39f107 --- /dev/null +++ b/ImageReward/utils.py @@ -0,0 +1,184 @@ +''' +@File : utils.py +@Time : 2023/04/05 19:18:00 +@Auther : Jiazheng Xu +@Contact : xjz22@mails.tsinghua.edu.cn +* Based on CLIP code base +* https://github.com/openai/CLIP +* Checkpoint of CLIP/BLIP/Aesthetic are from: +* https://github.com/openai/CLIP +* https://github.com/salesforce/BLIP +* https://github.com/christophschuhmann/improved-aesthetic-predictor +''' + +import os +import urllib +from typing import Union, List +import pathlib + +import torch +from tqdm import 
tqdm +from huggingface_hub import hf_hub_download + +from .ImageReward import ImageReward +from .models.CLIPScore import CLIPScore +from .models.BLIPScore import BLIPScore +from .models.AestheticScore import AestheticScore + +_MODELS = { + "ImageReward-v1.0": "https://huggingface.co/THUDM/ImageReward/blob/main/ImageReward.pt", +} + + +def available_models() -> List[str]: + """Returns the names of available ImageReward models""" + return list(_MODELS.keys()) + + +def ImageReward_download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + download_target = os.path.join(root, filename) + hf_hub_download(repo_id="THUDM/ImageReward", filename=filename, local_dir=root) + return download_target + + +def load(name: str = "ImageReward-v1.0", + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + download_root: str = None, + med_config_path: str = None): + """Load a ImageReward model + + Parameters + ---------- + name: str + A model name listed by `ImageReward.available_models()`, or the path to a model checkpoint containing the state_dict + device: Union[str, torch.device] + The device to put the loaded model + download_root: str + path to download the model files; by default, it uses "~/.cache/ImageReward" + med_config_path: str + + Returns + ------- + model : torch.nn.Module + The ImageReward model + """ + if name in _MODELS: + download_root = download_root or "~/.cache/ImageReward" + download_root = pathlib.Path(download_root) + model_path = pathlib.Path(download_root) / 'ImageReward.pt' + + if not model_path.exists(): + model_path = ImageReward_download(_MODELS[name], root=download_root.as_posix()) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + print('-> load ImageReward model from %s' % model_path) + state_dict = torch.load(model_path, map_location='cpu') + + # med_config + if med_config_path is None: + med_config_root = download_root or "~/.cache/ImageReward" + med_config_root = pathlib.Path(med_config_root) + med_config_path = med_config_root / 'med_config.json' + + if not med_config_path.exists(): + med_config_path = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json", + root=med_config_root.as_posix()) + print('-> load ImageReward med_config from %s' % med_config_path) + + model = ImageReward(device=device, med_config=med_config_path).to(device) + msg = model.load_state_dict(state_dict, strict=False) + model.eval() + + return model + + +_SCORES = { + "CLIP": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "BLIP": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth", + "Aesthetic": "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac%2Blogos%2Bava1-l14-linearMSE.pth", +} + + +def available_scores() -> List[str]: + """Returns the names of available ImageReward scores""" + return list(_SCORES.keys()) + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + return download_target + + with urllib.request.urlopen(url) as source, 
open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, + unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + return download_target + + +def load_score(name: str = "CLIP", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + download_root: str = None): + """Load a ImageReward model + + Parameters + ---------- + name : str + A model name listed by `ImageReward.available_models()` + + device : Union[str, torch.device] + The device to put the loaded model + + download_root: str + path to download the model files; by default, it uses "~/.cache/ImageReward" + + Returns + ------- + model : torch.nn.Module + The ImageReward model + """ + model_download_root = download_root or os.path.expanduser("~/.cache/ImageReward") + + if name in _SCORES: + model_path = _download(_SCORES[name], model_download_root) + else: + raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}") + + print('load checkpoint from %s' % model_path) + if name == "BLIP": + state_dict = torch.load(model_path, map_location='cpu') + med_config = ImageReward_download("https://huggingface.co/THUDM/ImageReward/blob/main/med_config.json", + model_download_root) + model = BLIPScore(med_config=med_config, device=device).to(device) + model.blip.load_state_dict(state_dict['model'], strict=False) + elif name == "CLIP": + model = CLIPScore(download_root=model_download_root, device=device).to(device) + elif name == "Aesthetic": + state_dict = torch.load(model_path, map_location='cpu') + model = AestheticScore(download_root=model_download_root, device=device).to(device) + model.mlp.load_state_dict(state_dict, strict=False) + else: + raise RuntimeError(f"Score {name} not found; available scores = {available_scores()}") + + print("checkpoint loaded") + model.eval() + + return model diff --git a/README.md b/README.md index ef54ed0f49a98c3aa767191d6d3303ad20691b9c..497d5043d045784149dbeb7601f58694c2169bf9 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,91 @@ # SVGDreamer: Text Guided SVG Generation with Diffusion Model +[![cvpr24](https://img.shields.io/badge/CVPR-2024-387ADF.svg)](https://arxiv.org/abs/2312.16476) [![arXiv](https://img.shields.io/badge/arXiv-2312.16476-b31b1b.svg)](https://arxiv.org/abs/2312.16476) -[![website](https://img.shields.io/badge/website-Gitpage-yellow)](https://ximinng.github.io/SVGDreamer-project/) +[![website](https://img.shields.io/badge/Website-Gitpage-4CCD99)](https://ximinng.github.io/SVGDreamer-project/) +[![blog](https://img.shields.io/badge/Blog-ENG-9195F6)](https://huggingface.co/blog/xingxm/svgdreamer) +[![blog](https://img.shields.io/badge/Blog-CN-9195F6)](https://huggingface.co/blog/xingxm/svgdreamer) -### Code coming soon !!! +This repository contains our official implementation of the CVPR 2024 paper: SVGDreamer: Text-Guided SVG Generation with +Diffusion Model. It can generate high-quality SVGs based on text prompts. -Our project page can be found [here](https://ximinng.github.io/SVGDreamer-project/). 
+[//]: # (> Project Page: https://ximinng.github.io/SVGDreamer-project/) -![title](./assets/teaser1.png) -![title](./assets/teaser2.png) -![title](./assets/teaser3.png) +![title](./assets/illustrate.png) +![title](./assets/teaser_svg_asset.png) -### TODO +## :new: Update -- [ ] release the complete code +- [03/2024] 🔥 We have released the **code** for [SVGDreamer](https://ximinng.github.io/SVGDreamer-project/). +- [02/2024] 🎉 **SVGDreamer accepted by CVPR2024.** 🎉 +- [12/2023] 🔥 We have released the **[SVGDreamer Paper](https://arxiv.org/abs/2312.16476)**. SVGDreamer is + a novel text-guided vector graphics synthesis method. This method considers both the editing of vector graphics and + the quality of the synthesis. + +## 🔥Quickstart + +Before running the code, download the stable diffusion model. Append `diffuser.download=True` to the end of the script. + +### SIVE + VPSD + +**Script:** + +```shell +python svgdreamer.py x=iconography skip_sive=False "prompt='an image of Batman. full body action pose, complete detailed body. white background. empty background, high quality, 4K, ultra realistic'" token_ind=4 x.vpsd.t_schedule='randint' result_path='./logs/batman' multirun=True mv=True +``` + +- `x=iconography`(str): style configs +- `skip_sive`(bool): enable the SIVE stage +- `token_ind`(int): the index of text prompt, from 1 +- `result_path`(str): the path to save the result +- `multirun`(bool): run the script multiple times with different random seeds +- `mv`(bool): save the intermediate results of the run and record the video (This increases the run time) + +**More parameters in `./conf/x/style.yaml`, you can modify these parameters from the command line. For +example, append `x.vpsd.n_particle=4` to the end of the script.** + +### VPSD + +**Prompt:** Sydney opera house. oil painting. by Van Gogh
+**Style:** iconography
+**Preview:**
+
+| Particle 1 | Particle 2 | Particle 3 | Particle 4 | Particle 5 | Particle 6 |
+|------------|------------|------------|------------|------------|------------|
+| init p1 | init p2 | init p3 | init p4 | init p5 | init p6 |
+| final p1 | final p2 | final p3 | final p4 | final p5 | final p6 |
+
+**Script:**
+
+```shell
+python svgdreamer.py x=iconography "prompt='Sydney opera house. oil painting. by Van Gogh'" result_path='./logs/SydneyOperaHouse-OilPainting'
+```
+
+**Other Styles:**
+
+```shell
+# Style: low-poly
+python svgdreamer.py x=lowpoly "prompt='A picture of a bald eagle. low-ploy. polygon'" result_path='./logs/BaldEagle'
+# Style: pixel-art
+python svgdreamer.py x=pixelart "prompt='Darth vader with lightsaber.'" result_path='./log/DarthVader'
+# Style: painting
+python svgdreamer.py x=painting "prompt='self portrait of Van Gogh. oil painting. cmyk portrait. multi colored. defiant and beautiful. cmyk. expressive eyes.'" result_path='./logs/VanGogh-Portrait'
+# Style: sketch
+python svgdreamer.py x=sketch "prompt='A free-hand drawing of A speeding Lamborghini. black and white drawing.'" result_path='./logs/Lamborghini'
+# Style: ink and wash
+python svgdreamer.py x=ink "prompt='Big Wild Goose Pagoda. ink style. Minimalist abstract art grayscale watercolor.'" result_path='./logs/BigWildGoosePagoda'
+```
+
+## 🔑 Tips
+
+- `x.vpsd.t_schedule` greatly affects the style of the result; it is worth trying several schedules.
+- `neg_prompt` (the negative prompt) affects the quality of the results.
+
+## 📋 TODO
+
+- [x] Release the code
 
 ## :books: Acknowledgement
 
@@ -22,6 +94,8 @@ The project is built based on the following repository:
 - [BachiLi/diffvg](https://github.com/BachiLi/diffvg)
 - [huggingface/diffusers](https://github.com/huggingface/diffusers)
 - [ximinng/DiffSketcher](https://github.com/ximinng/DiffSketcher)
+- [THUDM/ImageReward](https://github.com/THUDM/ImageReward)
+- [ximinng/PyTorch-SVGRender](https://github.com/ximinng/PyTorch-SVGRender)
 
 We gratefully thank the authors for their wonderful works.
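A side note on the reward model this diff vendors under `ImageReward/`: it can also be used on its own to score a rasterized SVG result against its text prompt. The sketch below is illustrative only and not part of the diff; it assumes the `ImageReward` package added here is importable and re-exports `load` from `ImageReward/utils.py` (as in upstream THUDM/ImageReward), and the image paths are hypothetical placeholders.

```python
# Minimal sketch: standalone scoring of rasterized SVG renders with the bundled ImageReward model.
# Assumptions: the `ImageReward` package added in this PR is on PYTHONPATH and re-exports `load`;
# the image paths below are placeholders, not files produced by this repository.
import ImageReward as RM
from PIL import Image

prompt = "Sydney opera house. oil painting. by Van Gogh"

# Downloads the ImageReward-v1.0 checkpoint and med_config.json to ~/.cache/ImageReward by default.
model = RM.load("ImageReward-v1.0")

# score() accepts a PIL.Image or an image file path and returns a scalar reward,
# standardized by the model's stored mean/std (higher is better).
reward = model.score(prompt, Image.open("./logs/SydneyOperaHouse-OilPainting/final_render.png"))
print(f"ImageReward: {reward:.4f}")

# inference_rank() compares several candidates for the same prompt and returns
# 1-based ranks alongside the raw rewards.
ranks, rewards = model.inference_rank(prompt, ["render_p0.png", "render_p1.png", "render_p2.png"])
print(ranks, rewards)
```

The snippet only demonstrates standalone scoring; how such rewards are consumed during SVGDreamer's training is handled by the training pipeline itself, not shown here.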
@@ -31,10 +105,10 @@ If you use this code for your research, please cite the following work:
 
 ```
 @article{xing2023svgdreamer,
-  title={SVGDreamer: Text Guided SVG Generation with Diffusion Model},
-  author={Xing, Ximing and Zhou, Haitao and Wang, Chuang and Zhang, Jing and Xu, Dong and Yu, Qian},
-  journal={arXiv preprint arXiv:2312.16476},
-  year={2023}
+  title={SVGDreamer: Text Guided SVG Generation with Diffusion Model},
+  author={Xing, Ximing and Zhou, Haitao and Wang, Chuang and Zhang, Jing and Xu, Dong and Yu, Qian},
+  journal={arXiv preprint arXiv:2312.16476},
+  year={2023}
 }
 ```
diff --git a/assets/Icon-SydneyOperaHouse/init_p0.svg b/assets/Icon-SydneyOperaHouse/init_p0.svg
diff --git a/assets/Icon-SydneyOperaHouse/init_p1.svg b/assets/Icon-SydneyOperaHouse/init_p1.svg
diff --git a/assets/Icon-SydneyOperaHouse/init_p2.svg b/assets/Icon-SydneyOperaHouse/init_p2.svg
diff --git a/assets/Icon-SydneyOperaHouse/init_p3.svg b/assets/Icon-SydneyOperaHouse/init_p3.svg
diff --git a/assets/Icon-SydneyOperaHouse/init_p4.svg b/assets/Icon-SydneyOperaHouse/init_p4.svg
diff --git a/assets/Icon-SydneyOperaHouse/init_p5.svg b/assets/Icon-SydneyOperaHouse/init_p5.svg
diff --git a/assets/Icon-SydneyOperaHouse/p_0.svg b/assets/Icon-SydneyOperaHouse/p_0.svg
diff --git a/assets/Icon-SydneyOperaHouse/p_1.svg b/assets/Icon-SydneyOperaHouse/p_1.svg
[Each file above is added as a new SVG asset (new file mode 100644, @@ -0,0 +1,518 @@); the 518 lines of vector path markup per file were stripped during extraction and are not reproduced here.]
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/Icon-SydneyOperaHouse/p_2.svg b/assets/Icon-SydneyOperaHouse/p_2.svg new file mode 100644 index 0000000000000000000000000000000000000000..954f8f0b8597b3eb69b713318f9133e68e2625dc --- /dev/null +++ b/assets/Icon-SydneyOperaHouse/p_2.svg @@ -0,0 +1,518 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/Icon-SydneyOperaHouse/p_3.svg b/assets/Icon-SydneyOperaHouse/p_3.svg new file mode 100644 index 0000000000000000000000000000000000000000..c2edd4bd637066d68dd1189153e2f7e429777a67 --- /dev/null +++ b/assets/Icon-SydneyOperaHouse/p_3.svg @@ -0,0 +1,518 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/Icon-SydneyOperaHouse/p_4.svg b/assets/Icon-SydneyOperaHouse/p_4.svg new file mode 100644 index 0000000000000000000000000000000000000000..ef05df2840ec471cbc4b9ba74a9c65d5be9a8c5a --- /dev/null +++ b/assets/Icon-SydneyOperaHouse/p_4.svg @@ -0,0 +1,518 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/Icon-SydneyOperaHouse/p_5.svg b/assets/Icon-SydneyOperaHouse/p_5.svg new file mode 100644 index 0000000000000000000000000000000000000000..7bc9bc39f405bbae8af214bc1f1b6a3b037815d1 --- /dev/null +++ b/assets/Icon-SydneyOperaHouse/p_5.svg @@ -0,0 +1,518 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/illustrate.png b/assets/illustrate.png new file mode 100644 index 0000000000000000000000000000000000000000..4aa1579df2d4bb6023f40ccccb5161c5cfb5c1a8 --- /dev/null +++ b/assets/illustrate.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bdecb562c3e69a6225f7f60b3a370372a79bcee448041850467a2fd88b808b5 +size 262728 diff --git a/assets/teaser1.png b/assets/teaser1.png deleted file mode 100644 index a8f327684cd50ca28b6e3742463db38218d27a37..0000000000000000000000000000000000000000 --- a/assets/teaser1.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e77ba5ac4b7dc26e26621646ffc9437dfd1635073bfd6736393d82e474422be3 -size 4136095 diff --git a/assets/teaser2.png b/assets/teaser2.png deleted file mode 100644 index 2185c1143fbe719d7322c6d0fa7d949d7ae1691a..0000000000000000000000000000000000000000 --- a/assets/teaser2.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:57ff58ff59454c858c957642763a46ea732afec4514ba836b591b3bb3e2369a6 -size 4907074 diff --git a/assets/teaser3.png b/assets/teaser3.png deleted file mode 100644 index 8107d05044181d90a020a3204bc473a6da38c9c5..0000000000000000000000000000000000000000 --- a/assets/teaser3.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid 
sha256:94c97684f04642f7d34603cd65d682c4a226e6ab9a42af70f3a7e5545644a086 -size 1435302 diff --git a/assets/teaser_cases.png b/assets/teaser_cases.png new file mode 100644 index 0000000000000000000000000000000000000000..fe3301b09443d7fb41495d2284b68288f5db615a --- /dev/null +++ b/assets/teaser_cases.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27f87a28133f8947909699a0791381d21fbb1384c1228555cb5ea07b7817bfeb +size 5345736 diff --git a/assets/teaser_more_cases.png b/assets/teaser_more_cases.png new file mode 100644 index 0000000000000000000000000000000000000000..16340e87f601db8357730ea0233545e9b57442f2 --- /dev/null +++ b/assets/teaser_more_cases.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e4e2505f41fd067829d3aeadc774a349546460803c48170e84ebae6e6646a23 +size 7700916 diff --git a/assets/teaser_svg_asset.png b/assets/teaser_svg_asset.png new file mode 100644 index 0000000000000000000000000000000000000000..33b2b54240cce84386b016725df1de2f426a4133 --- /dev/null +++ b/assets/teaser_svg_asset.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69218254bb2274825271c1b283804a026e23990b2da1b19e64d8de6ea8774666 +size 3718178 diff --git a/conf/config.yaml b/conf/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae8f33fbfed5f90512639528e4f2ee363da7a901 --- /dev/null +++ b/conf/config.yaml @@ -0,0 +1,54 @@ +#-----------------# +# Global Config # +#-----------------# + +# common args +prompt: ~ +token_ind: 1 # the index of text prompt, from 1 +neg_prompt: ~ # negative prompt +skip_sive: True # optimize from scratch without SIVE init + +# Accelerate config +state: + cpu: False # use cpu + mprec: no # mixed precision, choices: 'no', 'fp16', 'bf16' + +# Diffusers config +diffuser: + download: False # Set this variable to True the first time it runs + force_download: False + resume_download: False + +# PyDiffVG config +diffvg: + print_timing: False + +# reproduction +seed: 951222 +# multi-run +multirun: False +srange: ~ # seed range, example: [100, 100] + +# log +result_path: './workspace' +save_step: 50 + +# visual rendering process +mv: False # make video +framefreq: 5 # save the image interval +framerate: 24 # by adjusting the frame rate, you can control the playback speed of the output video + +# hydra setting +hydra: + help: + # app name, override to match the name your app is known by + app_name: 'SVGDreamer' + run: + # output directory for normal runs + # warning: make sure that the L53-55 of './libs/model_state.py' and 'dir' are modified together + dir: ./${result_path}/SVGDreamer-${now:%Y-%m-%d-%H-%M} + +# default settings +defaults: + - _self_ + - x: ~ \ No newline at end of file diff --git a/conf/x/iconography.yaml b/conf/x/iconography.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ffb235879b4900f8d030aeee3155fb01ddd4b351 --- /dev/null +++ b/conf/x/iconography.yaml @@ -0,0 +1,188 @@ +image_size: 600 # canvas size +path_svg: ~ # if you want to load a svg file and train from it +color_init: 'rand' # if skip_live=True, then use color_init to init target_img +style: "iconography" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink" + +# stable diffusion in SIVE stage +sive_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# lr and optim +sive_stage_optim: 
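+  # note: judging by the '# lr and optim' heading above, the values below are per-parameter learning
+  # rates for the SIVE stage (control-point coordinates, stroke width, fill/stroke color, background);
+  # this reading is inferred from the config alone and is not verified against the optimizer code.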
+ point: 1 # control points + width: 0.1 # stroke width + color: 0.01 # fill color and stroke color + bg: 0.01 # bg in render_warp + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'linear' + keep_ratio: 0.2 + decay_ratio: 0.4 + +# SIVE rendering +sive: + attn_cfg: # init content via attn + cross_attn_res: 16 + self_attn_res: 32 + max_com: 20 + mean_comp: False + comp_idx: 0 + attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn + bg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 20 + coord_init: 'random' # 'sparse', 'random', 'naive'. place the first control point + grid: 20 + # optim + lr_schedule: True + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.001 + fg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 # number of strokes + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 15 + coord_init: 'random' # 'random', 'naive', place the first control point + grid: 20 + # optim + lr_schedule: False + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.01 + tog: # for refinement + reinit: True # if False, use fg params to init content + num_iter: 1000 + # optim + lr_schedule: False # enable lr_scheduler or not + # loss + bg_lam: 0 + fg_lam: 1 + xing_loss_weight: 0 + +# VPSD primitives +num_paths: 512 # number of strokes +trainable_bg: False # set the background to be trainable +width: 3 # stroke width +num_segments: 4 +segment_init: 'circle' # 'random' +radius: 20 +coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point +grid: 50 # divide the canvas into n grids +path_reinit: # reinitializing paths + use: True + freq: 100 # every 50 iterations + stop_step: 1000 # for VPSD fine-tuning + opacity_threshold: 0.05 + area_threshold: 64 + +# lr and optim +vpsd_stage_optim: + point: 1 + width: 0.1 + color: 0.01 + bg: 0.01 + lr_schedule: True # use lr_scheduler + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'cosine' + warmup_steps: 10 + warmup_start_lr: 0.02 + warmup_end_lr: 0.9 + cosine_end_lr: 0.4 + +# stable diffusion in VPSD stage +vpsd_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# VPSD setting +vpsd: + type: 'vpsd' + n_particle: 6 # 4, 8, 16 + vsd_n_particle: 4 # the batch size of particles + particle_aug: False # do data enhancement for the input particles + num_iter: 2000 # total iterations + guidance_scale: 7.5 # CFG value + grad_scale: 1.0 # increase or decrease the gradient + grad_clip_val: ~ # eg: 10, clip the gradient of VPSD + t_range: [ 0.02, 0.98 ] + # 'randint': random time steps, this may have a more authentic style. + # 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may have a more colorful results. 
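+  # (assumed reading of the schedule names, based only on the comment above rather than the scheduler
+  #  code: 'max_<t>_<N>' appears to anneal the upper bound of the sampled timestep from 0.98 toward <t>
+  #  around iteration <N>, so 'max_0.5_1500' would favor t <= 0.5 in the later part of training.)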
+ t_schedule: 'max_0.5_1500' # or 'randint' + # phi model config + phi_single: False # if False new an unet model to estimate noise + phi_model: 'lora' # 'lora', 'unet_simple' + use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not + lora_attn_scale: 1.0 # the scale of the attn based lora layer + phi_guidance_scale: 1.0 + phi_t: False # different t for phi fine-tuning + phi_update_step: 1 # enable multi-update phi model or not + phi_lr: 0.0001 # learning rate of phi model + phi_scheduler: 'ddim' + phi_n_particle: 2 # the batch size of phi_model + # ReFL config + phi_ReFL: False # enable reward feed back learning + n_phi_sample: 1 # number of samples used in ReFL + phi_sample_step: 200 # the phi log step + phi_infer_step: 50 # the phi num_inference_steps + # phi model optim + phi_optim: + name: 'adamw' + betas: [ 0.9, 0.999 ] + eps: 1e-8 + weight_decay: ~ # 1e-5 + # phi model lr learning schedule + phi_schedule: + use: False + name: 'cosine' + warmup_steps: 50 + warmup_start_lr: 0.00001 + warmup_end_lr: 0.0001 + total_step: 800 + cosine_end_lr: 0.0001 + +# reward model +reward_path: './checkpoint/ImageReward' + +# xing loss for closed-form paths +xing_loss: + use: False + weight: 0.01 diff --git a/conf/x/ink.yaml b/conf/x/ink.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aabf2804a146c114d5e0adf44e359579c303bc3c --- /dev/null +++ b/conf/x/ink.yaml @@ -0,0 +1,188 @@ +image_size: 600 # canvas size +path_svg: ~ # if you want to load a svg file and train from it +color_init: 'rand' # if skip_live=True, then use color_init to init target_img +style: "ink" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink" + +# stable diffusion in SIVE stage +sive_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# lr and optim +sive_stage_optim: + point: 1 # control points + width: 0.1 # stroke width + color: 0.01 # fill color and stroke color + bg: 0.01 # bg in render_warp + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'linear' + keep_ratio: 0.2 + decay_ratio: 0.4 + +# SIVE rendering +sive: + attn_cfg: # init content via attn + cross_attn_res: 16 + self_attn_res: 32 + max_com: 20 + mean_comp: False + comp_idx: 0 + attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn + bg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 20 + coord_init: 'random' # 'sparse', 'random', 'naive'. 
place the first control point + grid: 20 + # optim + lr_schedule: True + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.001 + fg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 # number of strokes + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 15 + coord_init: 'random' # 'random', 'naive', place the first control point + grid: 20 + # optim + lr_schedule: False + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.01 + tog: # for refinement + reinit: True # if False, use fg params to init content + num_iter: 1000 + # optim + lr_schedule: False # enable lr_scheduler or not + # loss + bg_lam: 0 + fg_lam: 1 + xing_loss_weight: 0 + +# VPSD primitives +num_paths: 128 # number of strokes +trainable_bg: False # set the background to be trainable +width: 6 # stroke width +num_segments: 4 +segment_init: 'circle' # 'random' +radius: 20 +coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point +grid: 50 # divide the canvas into n grids +path_reinit: # reinitializing paths + use: True + freq: 100 # every 50 iterations + stop_step: 1000 # for VPSD fine-tuning + opacity_threshold: 0.05 + area_threshold: 64 + +# lr and optim +vpsd_stage_optim: + point: 1 + width: 0.1 + color: 0.01 + bg: 0.01 + lr_schedule: True # use lr_scheduler + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'cosine' + warmup_steps: 10 + warmup_start_lr: 0.02 + warmup_end_lr: 0.9 + cosine_end_lr: 0.4 + +# stable diffusion in VPSD stage +vpsd_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# VPSD setting +vpsd: + type: 'vpsd' + n_particle: 6 # 4, 8, 16 + vsd_n_particle: 4 # the batch size of particles + particle_aug: False # do data enhancement for the input particles + num_iter: 2000 # total iterations + guidance_scale: 7.5 # CFG value + grad_scale: 1.0 # increase or decrease the gradient + grad_clip_val: ~ # eg: 10, clip the gradient of VPSD + t_range: [ 0.02, 0.98 ] + # 'randint': random time steps, this may have a more authentic style. + # 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may have a more colorful results. 
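+  # note: among the six style configs added here, ink, painting and sketch set t_schedule to 'randint',
+  # while iconography, low-poly and pixelart use the annealed 'max_0.5_1500' schedule.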
+ t_schedule: 'randint' # or 'randint' + # phi model config + phi_single: False # if False new an unet model to estimate noise + phi_model: 'lora' # 'lora', 'unet_simple' + use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not + lora_attn_scale: 1.0 # the scale of the attn based lora layer + phi_guidance_scale: 1.0 + phi_t: False # different t for phi fine-tuning + phi_update_step: 1 # enable multi-update phi model or not + phi_lr: 0.0001 # learning rate of phi model + phi_scheduler: 'ddim' + phi_n_particle: 2 # the batch size of phi_model + # ReFL config + phi_ReFL: False # enable reward feed back learning + n_phi_sample: 1 # number of samples used in ReFL + phi_sample_step: 200 # the phi log step + phi_infer_step: 50 # the phi num_inference_steps + # phi model optim + phi_optim: + name: 'adamw' + betas: [ 0.9, 0.999 ] + eps: 1e-8 + weight_decay: ~ # 1e-5 + # phi model lr learning schedule + phi_schedule: + use: False + name: 'cosine' + warmup_steps: 50 + warmup_start_lr: 0.00001 + warmup_end_lr: 0.0001 + total_step: 800 + cosine_end_lr: 0.0001 + +# reward model +reward_path: './checkpoint/ImageReward' + +# xing loss for closed-form paths +xing_loss: + use: False + weight: 0.01 diff --git a/conf/x/lowpoly.yaml b/conf/x/lowpoly.yaml new file mode 100644 index 0000000000000000000000000000000000000000..07196462bf0391842de904d07e622a0b4dcc3899 --- /dev/null +++ b/conf/x/lowpoly.yaml @@ -0,0 +1,188 @@ +image_size: 600 # canvas size +path_svg: ~ # if you want to load a svg file and train from it +color_init: 'rand' # if skip_live=True, then use color_init to init target_img +style: "low-poly" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink" + +# stable diffusion in SIVE stage +sive_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# lr and optim +sive_stage_optim: + point: 1 # control points + width: 0.1 # stroke width + color: 0.01 # fill color and stroke color + bg: 0.01 # bg in render_warp + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'linear' + keep_ratio: 0.2 + decay_ratio: 0.4 + +# SIVE rendering +sive: + attn_cfg: # init content via attn + cross_attn_res: 16 + self_attn_res: 32 + max_com: 20 + mean_comp: False + comp_idx: 0 + attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn + bg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 20 + coord_init: 'random' # 'sparse', 'random', 'naive'. 
place the first control point + grid: 20 + # optim + lr_schedule: True + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.001 + fg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 # number of strokes + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 15 + coord_init: 'random' # 'random', 'naive', place the first control point + grid: 20 + # optim + lr_schedule: False + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.01 + tog: # for refinement + reinit: True # if False, use fg params to init content + num_iter: 1000 + # optim + lr_schedule: False # enable lr_scheduler or not + # loss + bg_lam: 0 + fg_lam: 1 + xing_loss_weight: 0 + +# VPSD primitives +num_paths: 512 # number of strokes +trainable_bg: False # set the background to be trainable +width: 3 # stroke width +num_segments: 4 +segment_init: 'circle' # 'random' +radius: 20 +coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point +grid: 30 # divide the canvas into n grids +path_reinit: # reinitializing paths + use: True + freq: 100 # every 50 iterations + stop_step: 1000 # for VPSD fine-tuning + opacity_threshold: 0.05 + area_threshold: 64 + +# lr and optim +vpsd_stage_optim: + point: 1 + width: 0.1 + color: 0.01 + bg: 0.01 + lr_schedule: True # use lr_scheduler + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'cosine' + warmup_steps: 10 + warmup_start_lr: 0.02 + warmup_end_lr: 0.9 + cosine_end_lr: 0.4 + +# stable diffusion in VPSD stage +vpsd_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# VPSD setting +vpsd: + type: 'vpsd' + n_particle: 6 # 4, 8, 16 + vsd_n_particle: 4 # the batch size of particles + particle_aug: False # do data enhancement for the input particles + num_iter: 1500 # total iterations + guidance_scale: 7.5 # CFG value + grad_scale: 1.0 # increase or decrease the gradient + grad_clip_val: ~ # eg: 10, clip the gradient of VPSD + t_range: [ 0.02, 0.98 ] + # 'randint': random time steps, this may have a more authentic style. + # 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may have a more colorful results. 
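+  # note: vpsd.num_iter differs per style in these configs (1500 for low-poly, 1000 for pixelart,
+  # 2000 for the others), while path_reinit.stop_step stays at 1000 everywhere.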
+ t_schedule: 'max_0.5_1500' # or 'randint' + # phi model config + phi_single: False # if False new an unet model to estimate noise + phi_model: 'lora' # 'lora', 'unet_simple' + use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not + lora_attn_scale: 1.0 # the scale of the attn based lora layer + phi_guidance_scale: 1.0 + phi_t: False # different t for phi fine-tuning + phi_update_step: 1 # enable multi-update phi model or not + phi_lr: 0.0001 # learning rate of phi model + phi_scheduler: 'ddim' + phi_n_particle: 2 # the batch size of phi_model + # ReFL config + phi_ReFL: False # enable reward feed back learning + n_phi_sample: 1 # number of samples used in ReFL + phi_sample_step: 200 # the phi log step + phi_infer_step: 50 # the phi num_inference_steps + # phi model optim + phi_optim: + name: 'adamw' + betas: [ 0.9, 0.999 ] + eps: 1e-8 + weight_decay: ~ # 1e-5 + # phi model lr learning schedule + phi_schedule: + use: False + name: 'cosine' + warmup_steps: 50 + warmup_start_lr: 0.00001 + warmup_end_lr: 0.0001 + total_step: 800 + cosine_end_lr: 0.0001 + +# reward model +reward_path: './checkpoint/ImageReward' + +# xing loss for closed-form paths +xing_loss: + use: False + weight: 0.01 diff --git a/conf/x/painting.yaml b/conf/x/painting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c57708ab5063ada435b963c5178e0c129163283 --- /dev/null +++ b/conf/x/painting.yaml @@ -0,0 +1,188 @@ +image_size: 600 # canvas size +path_svg: ~ # if you want to load a svg file and train from it +color_init: 'rand' # if skip_live=True, then use color_init to init target_img +style: "painting" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink" + +# stable diffusion in SIVE stage +sive_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# lr and optim +sive_stage_optim: + point: 1 # control points + width: 0.1 # stroke width + color: 0.01 # fill color and stroke color + bg: 0.01 # bg in render_warp + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'linear' + keep_ratio: 0.2 + decay_ratio: 0.4 + +# SIVE rendering +sive: + attn_cfg: # init content via attn + cross_attn_res: 16 + self_attn_res: 32 + max_com: 20 + mean_comp: False + comp_idx: 0 + attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn + bg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 20 + coord_init: 'random' # 'sparse', 'random', 'naive'. 
place the first control point + grid: 20 + # optim + lr_schedule: True + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.001 + fg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 # number of strokes + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 15 + coord_init: 'random' # 'random', 'naive', place the first control point + grid: 20 + # optim + lr_schedule: False + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.01 + tog: # for refinement + reinit: True # if False, use fg params to init content + num_iter: 1000 + # optim + lr_schedule: False # enable lr_scheduler or not + # loss + bg_lam: 0 + fg_lam: 1 + xing_loss_weight: 0 + +# VPSD primitives +num_paths: 1500 # number of strokes +trainable_bg: False # set the background to be trainable +width: 3 # stroke width +num_segments: 4 +segment_init: 'circle' # 'random' +radius: 20 +coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point +grid: 50 # divide the canvas into n grids +path_reinit: # reinitializing paths + use: True + freq: 100 # every 50 iterations + stop_step: 1000 # for VPSD fine-tuning + opacity_threshold: 0.05 + area_threshold: 64 + +# lr and optim +vpsd_stage_optim: + point: 1 + width: 0.1 + color: 0.01 + bg: 0.01 + lr_schedule: True # use lr_scheduler + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'cosine' + warmup_steps: 10 + warmup_start_lr: 0.02 + warmup_end_lr: 0.9 + cosine_end_lr: 0.4 + +# stable diffusion in VPSD stage +vpsd_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# VPSD setting +vpsd: + type: 'vpsd' + n_particle: 6 # 4, 8, 16 + vsd_n_particle: 4 # the batch size of particles + particle_aug: False # do data enhancement for the input particles + num_iter: 2000 # total iterations + guidance_scale: 7.5 # CFG value + grad_scale: 1.0 # increase or decrease the gradient + grad_clip_val: ~ # eg: 10, clip the gradient of VPSD + t_range: [ 0.02, 0.98 ] + # 'randint': random time steps, this may have a more authentic style. + # 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may have a more colorful results. 
+ t_schedule: 'randint' # or 'randint' + # phi model config + phi_single: False # if False new an unet model to estimate noise + phi_model: 'lora' # 'lora', 'unet_simple' + use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not + lora_attn_scale: 1.0 # the scale of the attn based lora layer + phi_guidance_scale: 1.0 + phi_t: False # different t for phi fine-tuning + phi_update_step: 1 # enable multi-update phi model or not + phi_lr: 0.0001 # learning rate of phi model + phi_scheduler: 'ddim' + phi_n_particle: 2 # the batch size of phi_model + # ReFL config + phi_ReFL: False # enable reward feed back learning + n_phi_sample: 1 # number of samples used in ReFL + phi_sample_step: 200 # the phi log step + phi_infer_step: 50 # the phi num_inference_steps + # phi model optim + phi_optim: + name: 'adamw' + betas: [ 0.9, 0.999 ] + eps: 1e-8 + weight_decay: ~ # 1e-5 + # phi model lr learning schedule + phi_schedule: + use: False + name: 'cosine' + warmup_steps: 50 + warmup_start_lr: 0.00001 + warmup_end_lr: 0.0001 + total_step: 800 + cosine_end_lr: 0.0001 + +# reward model +reward_path: './checkpoint/ImageReward' + +# xing loss for closed-form paths +xing_loss: + use: False + weight: 0.01 diff --git a/conf/x/pixelart.yaml b/conf/x/pixelart.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f26252c1d075436fa8400a256bad4c7b0203c567 --- /dev/null +++ b/conf/x/pixelart.yaml @@ -0,0 +1,188 @@ +image_size: 600 # canvas size +path_svg: ~ # if you want to load a svg file and train from it +color_init: 'rand' # if skip_live=True, then use color_init to init target_img +style: "pixelart" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink" + +# stable diffusion in SIVE stage +sive_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# lr and optim +sive_stage_optim: + point: 1 # control points + width: 0.1 # stroke width + color: 0.01 # fill color and stroke color + bg: 0.01 # bg in render_warp + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'linear' + keep_ratio: 0.2 + decay_ratio: 0.4 + +# SIVE rendering +sive: + attn_cfg: # init content via attn + cross_attn_res: 16 + self_attn_res: 32 + max_com: 20 + mean_comp: False + comp_idx: 0 + attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn + bg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 20 + coord_init: 'random' # 'sparse', 'random', 'naive'. 
place the first control point + grid: 20 + # optim + lr_schedule: True + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.001 + fg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 # number of strokes + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 15 + coord_init: 'random' # 'random', 'naive', place the first control point + grid: 20 + # optim + lr_schedule: False + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.01 + tog: # for refinement + reinit: True # if False, use fg params to init content + num_iter: 1000 + # optim + lr_schedule: False # enable lr_scheduler or not + # loss + bg_lam: 0 + fg_lam: 1 + xing_loss_weight: 0 + +# VPSD primitives +num_paths: 512 # number of strokes +trainable_bg: False # set the background to be trainable +width: 3 # stroke width +num_segments: 4 +segment_init: 'circle' # 'random' +radius: 20 +coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point +grid: 50 # divide the canvas into n grids +path_reinit: # reinitializing paths + use: True + freq: 100 # every 50 iterations + stop_step: 1000 # for VPSD fine-tuning + opacity_threshold: 0.05 + area_threshold: 64 + +# lr and optim +vpsd_stage_optim: + point: 1 + width: 0.1 + color: 0.01 + bg: 0.01 + lr_schedule: True # use lr_scheduler + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'cosine' + warmup_steps: 10 + warmup_start_lr: 0.02 + warmup_end_lr: 0.9 + cosine_end_lr: 0.4 + +# stable diffusion in VPSD stage +vpsd_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# VPSD setting +vpsd: + type: 'vpsd' + n_particle: 6 # 4, 8, 16 + vsd_n_particle: 4 # the batch size of particles + particle_aug: False # do data enhancement for the input particles + num_iter: 1000 # total iterations + guidance_scale: 7.5 # CFG value + grad_scale: 1.0 # increase or decrease the gradient + grad_clip_val: ~ # eg: 10, clip the gradient of VPSD + t_range: [ 0.02, 0.98 ] + # 'randint': random time steps, this may have a more authentic style. + # 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may have a more colorful results. 
+ t_schedule: 'max_0.5_1500' # or 'randint' + # phi model config + phi_single: False # if False new an unet model to estimate noise + phi_model: 'lora' # 'lora', 'unet_simple' + use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not + lora_attn_scale: 1.0 # the scale of the attn based lora layer + phi_guidance_scale: 1.0 + phi_t: False # different t for phi fine-tuning + phi_update_step: 1 # enable multi-update phi model or not + phi_lr: 0.0001 # learning rate of phi model + phi_scheduler: 'ddim' + phi_n_particle: 2 # the batch size of phi_model + # ReFL config + phi_ReFL: False # enable reward feed back learning + n_phi_sample: 1 # number of samples used in ReFL + phi_sample_step: 200 # the phi log step + phi_infer_step: 50 # the phi num_inference_steps + # phi model optim + phi_optim: + name: 'adamw' + betas: [ 0.9, 0.999 ] + eps: 1e-8 + weight_decay: ~ # 1e-5 + # phi model lr learning schedule + phi_schedule: + use: False + name: 'cosine' + warmup_steps: 50 + warmup_start_lr: 0.00001 + warmup_end_lr: 0.0001 + total_step: 800 + cosine_end_lr: 0.0001 + +# reward model +reward_path: './checkpoint/ImageReward' + +# xing loss for closed-form paths +xing_loss: + use: False + weight: 0.01 diff --git a/conf/x/sketch.yaml b/conf/x/sketch.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cabfeb75f4776e8383ee36170c516ea624e4138f --- /dev/null +++ b/conf/x/sketch.yaml @@ -0,0 +1,188 @@ +image_size: 600 # canvas size +path_svg: ~ # if you want to load a svg file and train from it +color_init: 'rand' # if skip_live=True, then use color_init to init target_img +style: "sketch" # "iconography", "pixelart", "low-poly", "painting", "sketch", "ink" + +# stable diffusion in SIVE stage +sive_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# lr and optim +sive_stage_optim: + point: 1 # control points + width: 0.1 # stroke width + color: 0.01 # fill color and stroke color + bg: 0.01 # bg in render_warp + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'linear' + keep_ratio: 0.2 + decay_ratio: 0.4 + +# SIVE rendering +sive: + attn_cfg: # init content via attn + cross_attn_res: 16 + self_attn_res: 32 + max_com: 20 + mean_comp: False + comp_idx: 0 + attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn + bg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 20 + coord_init: 'random' # 'sparse', 'random', 'naive'. 
place the first control point + grid: 20 + # optim + lr_schedule: True + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.001 + fg: + style: "iconography" # 'iconography' ,"pixelart", "sketch", 'painting', 'ink' + num_iter: 10 + num_paths: 256 # number of strokes + path_schedule: 'repeat' # 'repeat', 'list' + schedule_each: 128 + width: 3 # sketch stroke width + num_segments: 4 + segment_init: 'circle' # 'random' + radius: 15 + coord_init: 'random' # 'random', 'naive', place the first control point + grid: 20 + # optim + lr_schedule: False + optim_bg: False # train background + use_attn_init: True + softmax_tau: 0.3 # temperature of softmax + # loss + use_distance_weighted_loss: False + xing_loss_weight: 0.01 + tog: # for refinement + reinit: True # if False, use fg params to init content + num_iter: 1000 + # optim + lr_schedule: False # enable lr_scheduler or not + # loss + bg_lam: 0 + fg_lam: 1 + xing_loss_weight: 0 + +# VPSD primitives +num_paths: 128 # number of strokes +trainable_bg: False # set the background to be trainable +width: 3 # stroke width +num_segments: 4 +segment_init: 'circle' # 'random' +radius: 20 +coord_init: 'random' # 'random', 'naive', 'sparse' place the first control point +grid: 50 # divide the canvas into n grids +path_reinit: # reinitializing paths + use: True + freq: 100 # every 50 iterations + stop_step: 1000 # for VPSD fine-tuning + opacity_threshold: 0.05 + area_threshold: 64 + +# lr and optim +vpsd_stage_optim: + point: 1 + width: 0.1 + color: 0.01 + bg: 0.01 + lr_schedule: True # use lr_scheduler + optim: + name: 'adam' + betas: [ 0.9, 0.9 ] + eps: 1e-6 + schedule: + name: 'cosine' + warmup_steps: 10 + warmup_start_lr: 0.02 + warmup_end_lr: 0.9 + cosine_end_lr: 0.4 + +# stable diffusion in VPSD stage +vpsd_model_cfg: + model_id: "sd21b" # sd14, sd15, sd21, sd21b, sdxl + ldm_speed_up: False + enable_xformers: True + gradient_checkpoint: False + cpu_offload: True + num_inference_steps: 100 + guidance_scale: 7.5 # sdxl default 5.0 + lora_path: ~ + +# VPSD setting +vpsd: + type: 'vpsd' + n_particle: 6 # 4, 8, 16 + vsd_n_particle: 4 # the batch size of particles + particle_aug: False # do data enhancement for the input particles + num_iter: 2000 # total iterations + guidance_scale: 7.5 # CFG value + grad_scale: 1.0 # increase or decrease the gradient + grad_clip_val: ~ # eg: 10, clip the gradient of VPSD + t_range: [ 0.02, 0.98 ] + # 'randint': random time steps, this may have a more authentic style. + # 'max_0.5_900': annealing from 0.98 to 0.5 after 900 steps, this may have a more colorful results. 
+ t_schedule: 'randint' # or 'randint' + # phi model config + phi_single: False # if False new an unet model to estimate noise + phi_model: 'lora' # 'lora', 'unet_simple' + use_attn_scale: ${x.vpsd.phi_single} # use lora_attn_scale or not + lora_attn_scale: 1.0 # the scale of the attn based lora layer + phi_guidance_scale: 1.0 + phi_t: False # different t for phi fine-tuning + phi_update_step: 1 # enable multi-update phi model or not + phi_lr: 0.0001 # learning rate of phi model + phi_scheduler: 'ddim' + phi_n_particle: 2 # the batch size of phi_model + # ReFL config + phi_ReFL: False # enable reward feed back learning + n_phi_sample: 1 # number of samples used in ReFL + phi_sample_step: 200 # the phi log step + phi_infer_step: 50 # the phi num_inference_steps + # phi model optim + phi_optim: + name: 'adamw' + betas: [ 0.9, 0.999 ] + eps: 1e-8 + weight_decay: ~ # 1e-5 + # phi model lr learning schedule + phi_schedule: + use: False + name: 'cosine' + warmup_steps: 50 + warmup_start_lr: 0.00001 + warmup_end_lr: 0.0001 + total_step: 800 + cosine_end_lr: 0.0001 + +# reward model +reward_path: './checkpoint/ImageReward' + +# xing loss for closed-form paths +xing_loss: + use: False + weight: 0.01 diff --git a/svgdreamer.py b/svgdreamer.py new file mode 100644 index 0000000000000000000000000000000000000000..46c4bd759ce37d4f729d97faf7eebdd291783341 --- /dev/null +++ b/svgdreamer.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Author: ximing xing +# Description: the main func of this project. +# Copyright (c) 2023, XiMing Xing. + +import os +import sys +from functools import partial + +from accelerate.utils import set_seed +import hydra +import omegaconf + +sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0]) + +from svgdreamer.utils import render_batch_wrap, get_seed_range +from svgdreamer.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline + + +@hydra.main(version_base=None, config_path="conf", config_name='config') +def main(cfg: omegaconf.DictConfig): + """ + The project configuration is stored in './conf/config.yaml’ + And style configurations are stored in './conf/x/iconographic.yaml’ + """ + + # set seed + set_seed(cfg.seed) + seed_range = get_seed_range(cfg.srange) if cfg.multirun else None + + # render function + render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range) + + if not cfg.multirun: # generate SVG multiple times + pipe = SVGDreamerPipeline(cfg) + pipe.painterly_rendering(cfg.prompt) + else: # generate many SVG at once + render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None) + + +if __name__ == '__main__': + main() diff --git a/svgdreamer/__init__.py b/svgdreamer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c96d34996877f0e807321fbfca4d4081f6844170 --- /dev/null +++ b/svgdreamer/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Copyright (c) 2023, XiMing Xing. +# License: MIT + +__version__ = "1.0" diff --git a/svgdreamer/diffusers_warp/__init__.py b/svgdreamer/diffusers_warp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9657b1234ca9fceb808eb9deda0e696c9b5fbbc2 --- /dev/null +++ b/svgdreamer/diffusers_warp/__init__.py @@ -0,0 +1,248 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. 
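+# Note: this module collects helpers for setting up diffusers components:
+#   - init_StableDiffusion_pipeline(): builds a StableDiffusionPipeline with model-id shortcuts and optional
+#     torch.compile, xformers attention, gradient checkpointing, CPU offload, VAE slicing and LoRA/UNet loading.
+#   - init_diffusers_unet(): builds a standalone UNet2DConditionModel with a similar, slightly smaller option set.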
+# Author: XiMing Xing +# Description: +from typing import AnyStr +import pathlib +from collections import OrderedDict +from packaging import version + +import torch +from diffusers import StableDiffusionPipeline, SchedulerMixin +from diffusers import UNet2DConditionModel +from diffusers.utils import is_torch_version, is_xformers_available + +DiffusersModels = OrderedDict({ + "sd14": "CompVis/stable-diffusion-v1-4", # resolution: 512 + "sd15": "runwayml/stable-diffusion-v1-5", # resolution: 512 + "sd21b": "stabilityai/stable-diffusion-2-1-base", # resolution: 512 + "sd21": "stabilityai/stable-diffusion-2-1", # resolution: 768 + "sdxl": "stabilityai/stable-diffusion-xl-base-1.0", # resolution: 1024 +}) + +# default resolution +_model2resolution = { + "sd14": 512, + "sd15": 512, + "sd21b": 512, + "sd21": 768, + "sdxl": 1024, +} + + +def model2res(model_id: str): + return _model2resolution.get(model_id, 512) + + +def init_StableDiffusion_pipeline(model_id: AnyStr, + custom_pipeline: StableDiffusionPipeline, + custom_scheduler: SchedulerMixin = None, + device: torch.device = "cuda", + torch_dtype: torch.dtype = torch.float32, + local_files_only: bool = True, + force_download: bool = False, + resume_download: bool = False, + ldm_speed_up: bool = False, + enable_xformers: bool = True, + gradient_checkpoint: bool = False, + cpu_offload: bool = False, + vae_slicing: bool = False, + lora_path: AnyStr = None, + unet_path: AnyStr = None) -> StableDiffusionPipeline: + """ + A tool for initial diffusers pipeline. + + Args: + model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path + custom_pipeline: any StableDiffusionPipeline pipeline + custom_scheduler: any scheduler + device: set device + torch_dtype: data type + local_files_only: prohibited download model + force_download: forced download model + resume_download: re-download model + ldm_speed_up: use the `torch.compile` api to speed up unet + enable_xformers: enable memory efficient attention from [xFormers] + gradient_checkpoint: activates gradient checkpointing for the current model + cpu_offload: enable sequential cpu offload + vae_slicing: enable sliced VAE decoding + lora_path: load LoRA checkpoint + unet_path: load unet checkpoint + + Returns: + diffusers.StableDiffusionPipeline + """ + + # get model id + model_id = DiffusersModels.get(model_id, model_id) + + # process diffusion model + if custom_scheduler is not None: + pipeline = custom_pipeline.from_pretrained( + model_id, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + force_download=force_download, + resume_download=resume_download, + scheduler=custom_scheduler.from_pretrained(model_id, + subfolder="scheduler", + local_files_only=local_files_only, + force_download=force_download, + resume_download=resume_download) + ).to(device) + else: + pipeline = custom_pipeline.from_pretrained( + model_id, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + force_download=force_download, + resume_download=resume_download, + ).to(device) + + print(f"load diffusers pipeline: {model_id}") + + # process unet model if exist + if unet_path is not None and pathlib.Path(unet_path).exists(): + print(f"=> load u-net from {unet_path}") + pipeline.unet.from_pretrained(model_id, subfolder="unet") + + # process lora layers if exist + if lora_path is not None and pathlib.Path(lora_path).exists(): + pipeline.unet.load_attn_procs(lora_path) + print(f"=> load lora layers into U-Net from {lora_path} ...") + + # torch.compile + if ldm_speed_up: + if 
is_torch_version(">=", "2.0.0"): + pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True) + print(f"=> enable torch.compile on U-Net") + else: + print(f"=> warning: calling torch.compile speed-up failed, since torch version <= 2.0.0") + + # Meta xformers + if enable_xformers: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + print( + "xFormers 0.0.16 cannot be used for training in some GPUs. " + "If you observe problems during training, please update xFormers to at least 0.0.17. " + "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + print(f"=> enable xformers") + pipeline.unet.enable_xformers_memory_efficient_attention() + else: + print(f"=> warning: xformers is not available.") + + # gradient checkpointing + if gradient_checkpoint: + # if pipeline.unet.is_gradient_checkpointing: + if True: + print(f"=> enable gradient checkpointing") + pipeline.unet.enable_gradient_checkpointing() + else: + print("=> waring: gradient checkpointing is not activated for this model.") + + if cpu_offload: + pipeline.enable_sequential_cpu_offload() + + if vae_slicing: + pipeline.enable_vae_slicing() + + print(pipeline.scheduler) + return pipeline + + +def init_diffusers_unet(model_id: AnyStr, + device: torch.device = "cuda", + torch_dtype: torch.dtype = torch.float32, + local_files_only: bool = True, + force_download: bool = False, + resume_download: bool = False, + ldm_speed_up: bool = False, + enable_xformers: bool = True, + gradient_checkpoint: bool = False, + lora_path: AnyStr = None, + unet_path: AnyStr = None): + """ + A tool for initial diffusers UNet model. + + Args: + model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path + device: set device + torch_dtype: data type + local_files_only: prohibited download model + force_download: forced download model + resume_download: re-download model + ldm_speed_up: use the `torch.compile` api to speed up unet + enable_xformers: enable memory efficient attention from [xFormers] + gradient_checkpoint: activates gradient checkpointing for the current model + lora_path: load LoRA checkpoint + unet_path: load unet checkpoint + + Returns: + diffusers.UNet + """ + + # get model id + model_id = DiffusersModels.get(model_id, model_id) + + # process UNet model + unet = UNet2DConditionModel.from_pretrained( + model_id, + subfolder="unet", + torch_dtype=torch_dtype, + local_files_only=local_files_only, + force_download=force_download, + resume_download=resume_download, + ).to(device) + + print(f"load diffusers UNet: {model_id}") + + # process unet model if exist + if unet_path is not None and pathlib.Path(unet_path).exists(): + print(f"=> load u-net from {unet_path}") + unet.from_pretrained(model_id) + + # process lora layers if exist + if lora_path is not None and pathlib.Path(lora_path).exists(): + unet.load_attn_procs(lora_path) + print(f"=> load lora layers into U-Net from {lora_path} ...") + + # torch.compile + if ldm_speed_up: + if is_torch_version(">=", "2.0.0"): + unet = torch.compile(unet, mode="reduce-overhead", fullgraph=True) + print(f"=> enable torch.compile on U-Net") + else: + print(f"=> warning: calling torch.compile speed-up failed, since torch version <= 2.0.0") + + # Meta xformers + if enable_xformers: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == 
version.parse("0.0.16"): + print( + "xFormers 0.0.16 cannot be used for training in some GPUs. " + "If you observe problems during training, please update xFormers to at least 0.0.17. " + "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + print(f"=> enable xformers") + unet.enable_xformers_memory_efficient_attention() + else: + print(f"=> warning: xformers is not available.") + + # gradient checkpointing + if gradient_checkpoint: + # if unet.is_gradient_checkpointing: + if True: + print(f"=> enable gradient checkpointing") + unet.enable_gradient_checkpointing() + else: + print("=> waring: gradient checkpointing is not activated for this model.") + + return unet diff --git a/svgdreamer/diffvg_warp/__init__.py b/svgdreamer/diffvg_warp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a97c9500f48e17361ce5ba881bc76d2322f5d7ce --- /dev/null +++ b/svgdreamer/diffvg_warp/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: + +from .diffvg_state import DiffVGState, init_pydiffvg + +__all__ = [ + 'DiffVGState', + 'init_pydiffvg' +] diff --git a/svgdreamer/diffvg_warp/diffvg_state.py b/svgdreamer/diffvg_warp/diffvg_state.py new file mode 100644 index 0000000000000000000000000000000000000000..78bf82a7ccb0ad56ae0276ca8ea3b35f99887a85 --- /dev/null +++ b/svgdreamer/diffvg_warp/diffvg_state.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: parent class +# Copyright (c) 2023, XiMing Xing. +# License: MIT License +import pathlib +from typing import AnyStr, List, Union +import xml.etree.ElementTree as etree + +import torch +import pydiffvg + + +def init_pydiffvg(device: torch.device, + use_gpu: bool = torch.cuda.is_available(), + print_timing: bool = False): + pydiffvg.set_use_gpu(use_gpu) + pydiffvg.set_device(device) + pydiffvg.set_print_timing(print_timing) + + +class DiffVGState(torch.nn.Module): + + def __init__(self, + device: torch.device, + use_gpu: bool = torch.cuda.is_available(), + print_timing: bool = False, + canvas_width: int = None, + canvas_height: int = None): + super(DiffVGState, self).__init__() + # pydiffvg device setting + self.device = device + init_pydiffvg(device, use_gpu, print_timing) + + # canvas size + self.canvas_width = canvas_width + self.canvas_height = canvas_height + + # record all paths + self.shapes = [] + self.shape_groups = [] + # record the current optimized path + self.cur_shapes = [] + self.cur_shape_groups = [] + + # learnable SVG params + self.point_vars = [] + self.color_vars = [] + self.width_vars = [] + + def clip_curve_shape(self, *args, **kwargs): + raise NotImplementedError + + def render_warp(self, seed=0): + self.clip_curve_shape() + + scene_args = pydiffvg.RenderFunction.serialize_scene( + self.canvas_width, self.canvas_height, self.shapes, self.shape_groups + ) + _render = pydiffvg.RenderFunction.apply + img = _render(self.canvas_width, # width + self.canvas_height, # height + 2, # num_samples_x + 2, # num_samples_y + seed, # seed + None, + *scene_args) + return img + + def render_image(self, canvas_width, canvas_height, shapes, shape_groups, seed: int = 0): + scene_args = pydiffvg.RenderFunction.serialize_scene( + canvas_width, canvas_height, shapes, shape_groups + ) + _render = pydiffvg.RenderFunction.apply + img = _render(canvas_width, # width + canvas_height, # height + 2, # num_samples_x + 2, # num_samples_y + seed, # seed + None, + *scene_args) + img = 
img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4]) + img = img.unsqueeze(0) # convert img from HWC to NCHW + img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW + return img + + @staticmethod + def load_svg(path_svg): + canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg) + return canvas_width, canvas_height, shapes, shape_groups + + def save_svg(self, + filename: Union[AnyStr, pathlib.Path], + width: int = None, + height: int = None, + shapes: List = None, + shape_groups: List = None, + use_gamma: bool = False, + background: str = None): + """ + Save an SVG file with specified parameters and shapes. + Noting: New version of SVG saving function that is an adaptation of pydiffvg.save_svg. + The original version saved words resulting in incomplete glyphs. + + Args: + filename (str): The path to save the SVG file. + width (int): The width of the SVG canvas. + height (int): The height of the SVG canvas. + shapes (list): A list of shapes to be included in the SVG. + shape_groups (list): A list of shape groups. + use_gamma (bool): Flag indicating whether to apply gamma correction. + background (str, optional): The background color of the SVG. + + Returns: + None + """ + root = etree.Element('svg') + root.set('version', '1.1') + root.set('xmlns', 'http://www.w3.org/2000/svg') + root.set('width', str(width)) + root.set('height', str(height)) + + if background is not None: + print(f"setting background to {background}") + root.set('style', str(background)) + + defs = etree.SubElement(root, 'defs') + g = etree.SubElement(root, 'g') + + if use_gamma: + f = etree.SubElement(defs, 'filter') + f.set('id', 'gamma') + f.set('x', '0') + f.set('y', '0') + f.set('width', '100%') + f.set('height', '100%') + gamma = etree.SubElement(f, 'feComponentTransfer') + gamma.set('color-interpolation-filters', 'sRGB') + feFuncR = etree.SubElement(gamma, 'feFuncR') + feFuncR.set('type', 'gamma') + feFuncR.set('amplitude', str(1)) + feFuncR.set('exponent', str(1 / 2.2)) + feFuncG = etree.SubElement(gamma, 'feFuncG') + feFuncG.set('type', 'gamma') + feFuncG.set('amplitude', str(1)) + feFuncG.set('exponent', str(1 / 2.2)) + feFuncB = etree.SubElement(gamma, 'feFuncB') + feFuncB.set('type', 'gamma') + feFuncB.set('amplitude', str(1)) + feFuncB.set('exponent', str(1 / 2.2)) + feFuncA = etree.SubElement(gamma, 'feFuncA') + feFuncA.set('type', 'gamma') + feFuncA.set('amplitude', str(1)) + feFuncA.set('exponent', str(1 / 2.2)) + g.set('style', 'filter:url(#gamma)') + + # Store color + for i, shape_group in enumerate(shape_groups): + def add_color(shape_color, name): + if isinstance(shape_color, pydiffvg.LinearGradient): + lg = shape_color + color = etree.SubElement(defs, 'linearGradient') + color.set('id', name) + color.set('x1', str(lg.begin[0].item())) + color.set('y1', str(lg.begin[1].item())) + color.set('x2', str(lg.end[0].item())) + color.set('y2', str(lg.end[1].item())) + offsets = lg.offsets.data.cpu().numpy() + stop_colors = lg.stop_colors.data.cpu().numpy() + for j in range(offsets.shape[0]): + stop = etree.SubElement(color, 'stop') + stop.set('offset', str(offsets[j])) + c = lg.stop_colors[j, :] + stop.set('stop-color', 'rgb({}, {}, {})'.format( + int(255 * c[0]), int(255 * c[1]), int(255 * c[2]) + )) + stop.set('stop-opacity', '{}'.format(c[3])) + if isinstance(shape_color, pydiffvg.RadialGradient): + lg = shape_color + color = etree.SubElement(defs, 'radialGradient') + color.set('id', name) + color.set('cx', str(lg.center[0].item() / width)) + color.set('cy', 
str(lg.center[1].item() / height)) + # this only support width=height + color.set('r', str(lg.radius[0].item() / width)) + offsets = lg.offsets.data.cpu().numpy() + stop_colors = lg.stop_colors.data.cpu().numpy() + for j in range(offsets.shape[0]): + stop = etree.SubElement(color, 'stop') + stop.set('offset', str(offsets[j])) + c = lg.stop_colors[j, :] + stop.set('stop-color', 'rgb({}, {}, {})'.format( + int(255 * c[0]), int(255 * c[1]), int(255 * c[2]) + )) + stop.set('stop-opacity', '{}'.format(c[3])) + + if shape_group.fill_color is not None: + add_color(shape_group.fill_color, 'shape_{}_fill'.format(i)) + if shape_group.stroke_color is not None: + add_color(shape_group.stroke_color, 'shape_{}_stroke'.format(i)) + + for i, shape_group in enumerate(shape_groups): + shape = shapes[shape_group.shape_ids[0]] + if isinstance(shape, pydiffvg.Circle): + shape_node = etree.SubElement(g, 'circle') + shape_node.set('r', str(shape.radius.item())) + shape_node.set('cx', str(shape.center[0].item())) + shape_node.set('cy', str(shape.center[1].item())) + elif isinstance(shape, pydiffvg.Polygon): + shape_node = etree.SubElement(g, 'polygon') + points = shape.points.data.cpu().numpy() + path_str = '' + for j in range(0, shape.points.shape[0]): + path_str += '{} {}'.format(points[j, 0], points[j, 1]) + if j != shape.points.shape[0] - 1: + path_str += ' ' + shape_node.set('points', path_str) + elif isinstance(shape, pydiffvg.Path): + for j, id in enumerate(shape_group.shape_ids): + shape = shapes[id] + if isinstance(shape, pydiffvg.Path): + if j == 0: + shape_node = etree.SubElement(g, 'path') + node_id = shape_node.get('id') + path_str = '' + + num_segments = shape.num_control_points.shape[0] + num_control_points = shape.num_control_points.data.cpu().numpy() + points = shape.points.data.cpu().numpy() + num_points = shape.points.shape[0] + path_str += 'M {} {}'.format(points[0, 0], points[0, 1]) + point_id = 1 + for j in range(0, num_segments): + if num_control_points[j] == 0: + p = point_id % num_points + path_str += ' L {} {}'.format( + points[p, 0], points[p, 1]) + point_id += 1 + elif num_control_points[j] == 1: + p1 = (point_id + 1) % num_points + path_str += ' Q {} {} {} {}'.format( + points[point_id, 0], points[point_id, 1], + points[p1, 0], points[p1, 1]) + point_id += 2 + elif num_control_points[j] == 2: + p2 = (point_id + 2) % num_points + path_str += ' C {} {} {} {} {} {}'.format( + points[point_id, 0], points[point_id, 1], + points[point_id + 1, 0], points[point_id + 1, 1], + points[p2, 0], points[p2, 1]) + point_id += 3 + if node_id is not None: + shape_node.set('id', node_id) # add id to Path + shape_node.set('d', path_str) + elif isinstance(shape, pydiffvg.Rect): + shape_node = etree.SubElement(g, 'rect') + shape_node.set('x', str(shape.p_min[0].item())) + shape_node.set('y', str(shape.p_min[1].item())) + shape_node.set('width', str(shape.p_max[0].item() - shape.p_min[0].item())) + shape_node.set('height', str(shape.p_max[1].item() - shape.p_min[1].item())) + elif isinstance(shape, pydiffvg.Ellipse): + shape_node = etree.SubElement(g, 'ellipse') + shape_node.set('cx', str(shape.center[0].item())) + shape_node.set('cy', str(shape.center[1].item())) + shape_node.set('rx', str(shape.radius[0].item())) + shape_node.set('ry', str(shape.radius[1].item())) + else: + raise NotImplementedError(f'shape type: {type(shape)} is not involved in pydiffvg.') + + shape_node.set('stroke-width', str(2 * shape.stroke_width.data.cpu().item())) + if shape_group.fill_color is not None: + if 
isinstance(shape_group.fill_color, pydiffvg.LinearGradient): + shape_node.set('fill', 'url(#shape_{}_fill)'.format(i)) + else: + c = shape_group.fill_color.data.cpu().numpy() + shape_node.set('fill', 'rgb({}, {}, {})'.format( + int(255 * c[0]), int(255 * c[1]), int(255 * c[2]))) + shape_node.set('opacity', str(c[3])) + else: + shape_node.set('fill', 'none') + if shape_group.stroke_color is not None: + if isinstance(shape_group.stroke_color, pydiffvg.LinearGradient): + shape_node.set('stroke', 'url(#shape_{}_stroke)'.format(i)) + else: + c = shape_group.stroke_color.data.cpu().numpy() + shape_node.set('stroke', 'rgb({}, {}, {})'.format( + int(255 * c[0]), int(255 * c[1]), int(255 * c[2]))) + shape_node.set('stroke-opacity', str(c[3])) + shape_node.set('stroke-linecap', 'round') + shape_node.set('stroke-linejoin', 'round') + + with open(filename, "w") as f: + f.write(pydiffvg.prettify(root)) + + @staticmethod + def save_image(img, filename, gamma=1): + if torch.is_tensor(img) and torch.device != 'cpu': + img = img.detach().cpu() + pydiffvg.imwrite(img, filename, gamma=gamma) diff --git a/svgdreamer/libs/__init__.py b/svgdreamer/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c787c0c196a096fac20deb40e3943726731bbee5 --- /dev/null +++ b/svgdreamer/libs/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: a self consistent system, +# including runner, trainer, loss function, EMA, optimizer, lr scheduler , and common utils. + +from .model_state import ModelState +from .optim import get_optimizer diff --git a/svgdreamer/libs/logging.py b/svgdreamer/libs/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..dd9828d4ad6b640cc0dd08583b4b762195ef1d96 --- /dev/null +++ b/svgdreamer/libs/logging.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: + +import os +import sys +import errno + + +def get_logger(logs_dir: str, file_name: str = "log.txt"): + logger = PrintLogger(os.path.join(logs_dir, file_name)) + sys.stdout = logger # record all python print + return logger + + +class PrintLogger(object): + + def __init__(self, fpath=None): + """ + python standard input/output records + """ + self.console = sys.stdout + self.file = None + if fpath is not None: + mkdir_if_missing(os.path.dirname(fpath)) + self.file = open(fpath, 'w') + + def __del__(self): + self.close() + + def __enter__(self): + pass + + def __exit__(self, *args): + self.close() + + def write(self, msg): + self.console.write(msg) + if self.file is not None: + self.file.write(msg) + + def write_in(self, msg): + """write in log only, not console""" + if self.file is not None: + self.file.write(msg) + + def flush(self): + self.console.flush() + if self.file is not None: + self.file.flush() + os.fsync(self.file.fileno()) + + def close(self): + self.console.close() + if self.file is not None: + self.file.close() + + +def mkdir_if_missing(dir_path): + try: + os.makedirs(dir_path) + except OSError as e: + if e.errno != errno.EEXIST: + raise diff --git a/svgdreamer/libs/model_state.py b/svgdreamer/libs/model_state.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7d1e2370d205d8721f4f9ec8eb91811a09c2a7 --- /dev/null +++ b/svgdreamer/libs/model_state.py @@ -0,0 +1,253 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. 
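# A minimal usage sketch of the get_logger()/PrintLogger pair above (illustrative;
# the directory name "./workdir/demo" is a made-up example): once installed, every
# print() is mirrored to the console and to the log file until the logger is closed.
from svgdreamer.libs.logging import get_logger

logger = get_logger(logs_dir="./workdir/demo", file_name="log.txt")
print("mirrored to the console and to ./workdir/demo/log.txt")
logger.write_in("recorded in the file only\n")
logger.flush()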
+# Author: XiMing Xing +# Description: + +from typing import Union, List +from pathlib import Path +from datetime import datetime +import logging + +from omegaconf import OmegaConf, DictConfig +from pprint import pprint +import torch +from accelerate import Accelerator + +from .logging import get_logger + + +class ModelState: + """ + Handling logger and `hugging face` accelerate training + + features: + - Precision + - Device + - Optimizer + - Logger (default: python system print and logging) + - Monitor (default: wandb, tensorboard) + """ + + def __init__( + self, + args: DictConfig, + log_path_suffix: str = None, + ignore_log=False, # whether to create log file or not + ) -> None: + self.args: DictConfig = args + # set cfg + self.state_cfg = args.state + self.x_cfg = args.x + + """check valid""" + mixed_precision = self.state_cfg.get("mprec") + # Bug: omegaconf convert 'no' to false + mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision + + """create working space""" + # rule: ['./config'. 'method_name', 'exp_name.yaml'] + # -> result_path: ./runs/{method_name}-{exp_name}, as a base folder + now_time = datetime.now().strftime('%Y-%m-%d-%H-%M') + results_folder = self.args.get("result_path", None) + if results_folder is None: + self.result_path = Path("./workdir") / f"SVGDreamer-{now_time}" + else: + self.result_path = Path(results_folder) / f"SVGDreamer-{now_time}" + + # update result_path: ./runs/{method_name}-{exp_name}/{log_path_suffix} + # noting: can be understood as "results dir / methods / ablation study / your result" + if log_path_suffix is not None: + self.result_path = self.result_path / f"{log_path_suffix}" + else: + self.result_path = self.result_path / f"SVGDreamer" + + """init visualized tracker""" + # TODO: monitor with WANDB or TENSORBOARD + self.log_with = [] + # if self.state_cfg.wandb: + # self.log_with.append(LoggerType.WANDB) + # if self.state_cfg.tensorboard: + # self.log_with.append(LoggerType.TENSORBOARD) + + """HuggingFace Accelerator""" + self.accelerator = Accelerator( + device_placement=True, + mixed_precision=mixed_precision, + cpu=True if self.state_cfg.cpu else False, + log_with=None if len(self.log_with) == 0 else self.log_with, + project_dir=self.result_path / "vis", + ) + + """logs""" + if self.accelerator.is_local_main_process: + # logging + self.log = logging.getLogger(__name__) + + # log results in a folder periodically + self.result_path.mkdir(parents=True, exist_ok=True) + if not ignore_log: + self.logger = get_logger( + logs_dir=self.result_path.as_posix(), + file_name=f"{now_time}-{args.seed}-log.txt" + ) + + print("==> system args: ") + sys_cfg = OmegaConf.masked_copy(args, ["x"]) + print(sys_cfg) + print("==> yaml config args: ") + print(self.x_cfg) + + print("\n***** Model State *****") + print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp}") + print(f"-> Weight dtype: {self.weight_dtype}") + + if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled: + print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}") + + print(f"-> Working Space: '{self.result_path}'") + + """glob step""" + self.step = 0 + + """log process""" + self.accelerator.wait_for_everyone() + print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}') + + self.print("-> state initialization complete \n") + + @property + def device(self): + return self.accelerator.device + + @property + def is_main_process(self): + return 
self.accelerator.is_main_process + + @property + def weight_dtype(self): + weight_dtype = torch.float32 + if self.accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + return weight_dtype + + @property + def n_gpus(self): + return self.accelerator.num_processes + + @property + def no_decay_params_names(self): + no_decay = [ + "bn", "LayerNorm", "GroupNorm", + ] + return no_decay + + def no_decay_params(self, model, weight_decay): + """optimization tricks""" + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in model.named_parameters() + if not any(nd in n for nd in self.no_decay_params_names) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p for n, p in model.named_parameters() + if any(nd in n for nd in self.no_decay_params_names) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + def optimized_params(self, model: torch.nn.Module, verbose=True) -> List: + """return parameters if `requires_grad` is True + + Args: + model: pytorch models + verbose: log optimized parameters + + Examples: + >>> params_optimized = self.optimized_params(uvit, verbose=True) + >>> optimizer = torch.optim.AdamW(params_optimized, lr=1e-3) + + Returns: + a list of parameters + """ + params_optimized = [] + for key, value in model.named_parameters(): + if value.requires_grad: + params_optimized.append(value) + if verbose: + self.print("\t {}, {}, {}".format(key, value.numel(), value.shape)) + return params_optimized + + def save_everything(self, fpath: str): + """Saving and loading the model, optimizer, RNG generators, and the GradScaler.""" + if not self.accelerator.is_main_process: + return + self.accelerator.save_state(fpath) + + def load_save_everything(self, fpath: str): + """Loading the model, optimizer, RNG generators, and the GradScaler.""" + self.accelerator.load_state(fpath) + + def save(self, milestone: Union[str, float, int], checkpoint: object) -> None: + if not self.accelerator.is_main_process: + return + + torch.save(checkpoint, self.result_path / f'model-{milestone}.pt') + + def save_in(self, root: Union[str, Path], checkpoint: object) -> None: + if not self.accelerator.is_main_process: + return + + torch.save(checkpoint, root) + + def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False): + ckpt = torch.load(path, map_location=self.device) + + unwrapped_model = self.accelerator.unwrap_model(model) + if rm_module_prefix: + unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()}) + else: + unwrapped_model.load_state_dict(ckpt) + return unwrapped_model + + def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]): + ckpt = torch.load(path, map_location=self.accelerator.device) + self.print(f"pretrained_dict len: {len(ckpt)}") + unwrapped_model = self.accelerator.unwrap_model(model) + model_dict = unwrapped_model.state_dict() + pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict} + model_dict.update(pretrained_dict) + unwrapped_model.load_state_dict(model_dict, strict=False) + self.print(f"selected pretrained_dict: {len(model_dict)}") + return unwrapped_model + + def print(self, *args, **kwargs): + """Use in replacement of `print()` to only print once per server.""" + self.accelerator.print(*args, **kwargs) + + def pretty_print(self, msg): + if self.accelerator.is_main_process: + pprint(dict(msg)) + + def close_tracker(self): + 
self.accelerator.end_training() + + def free_memory(self): + self.accelerator.clear() + + def close(self, msg: str = "Training complete."): + """Use in end of training.""" + self.free_memory() + + if torch.cuda.is_available(): + self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB') + if len(self.log_with) > 0: + self.close_tracker() + self.print(msg) diff --git a/svgdreamer/libs/optim.py b/svgdreamer/libs/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..61b5e1b646848b02223d509298fb9ca1b29b8982 --- /dev/null +++ b/svgdreamer/libs/optim.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: optimizers +# Copyright (c) 2023, XiMing Xing. +# License: MIT License +from functools import partial + +import torch +from omegaconf import DictConfig + + +def get_optimizer(optimizer_name, parameters, lr=None, config: DictConfig = None): + param_dict = {} + if optimizer_name == "adam": + optimizer = partial(torch.optim.Adam, params=parameters) + if lr is not None: + optimizer = partial(torch.optim.Adam, params=parameters, lr=lr) + if config.get('betas'): + param_dict['betas'] = config.betas + if config.get('weight_decay'): + param_dict['weight_decay'] = config.weight_decay + if config.get('eps'): + param_dict['eps'] = config.eps + elif optimizer_name == "adamW": + optimizer = partial(torch.optim.AdamW, params=parameters) + if lr is not None: + optimizer = partial(torch.optim.AdamW, params=parameters, lr=lr) + if config.get('betas'): + param_dict['betas'] = config.betas + if config.get('weight_decay'): + param_dict['weight_decay'] = config.weight_decay + if config.get('eps'): + param_dict['eps'] = config.eps + elif optimizer_name == "radam": + optimizer = partial(torch.optim.RAdam, params=parameters) + if lr is not None: + optimizer = partial(torch.optim.RAdam, params=parameters, lr=lr) + if config.get('betas'): + param_dict['betas'] = config.betas + if config.get('weight_decay'): + param_dict['weight_decay'] = config.weight_decay + elif optimizer_name == "sgd": + optimizer = partial(torch.optim.SGD, params=parameters) + if lr is not None: + optimizer = partial(torch.optim.SGD, params=parameters, lr=lr) + if config.get('momentum'): + param_dict['momentum'] = config.momentum + if config.get('weight_decay'): + param_dict['weight_decay'] = config.weight_decay + if config.get('nesterov'): + param_dict['nesterov'] = config.nesterov + else: + raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.") + + if len(param_dict.keys()) > 0: + return optimizer(**param_dict) + else: + return optimizer() diff --git a/svgdreamer/painter/VPSD_pipeline.py b/svgdreamer/painter/VPSD_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..6d51db70412ec8cdd91faa4ff051e8c1791e248b --- /dev/null +++ b/svgdreamer/painter/VPSD_pipeline.py @@ -0,0 +1,585 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. 
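# A usage sketch of get_optimizer() above, assuming an OmegaConf node carrying the
# optional fields the "adamW" branch reads; the model and hyper-parameter values
# here are illustrative, not taken from the project configs.
import torch
from omegaconf import OmegaConf
from svgdreamer.libs import get_optimizer

model = torch.nn.Linear(8, 2)
optim_cfg = OmegaConf.create({"betas": [0.9, 0.999], "weight_decay": 0.0, "eps": 1e-8})
optimizer = get_optimizer("adamW", model.parameters(), lr=1e-3, config=optim_cfg)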
+# Author: XiMing Xing +# Description: +import re +import PIL +from PIL import Image +from typing import Any, List, Optional, Union, Dict +from omegaconf import DictConfig + +import numpy as np +import torch +import torch.nn.functional as F +from torchvision import transforms +from diffusers import StableDiffusionPipeline, UNet2DConditionModel +from diffusers import DDIMScheduler +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( + rescale_noise_cfg, StableDiffusionPipelineOutput) +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.loaders import AttnProcsLayers + +from svgdreamer.diffusers_warp import init_StableDiffusion_pipeline, init_diffusers_unet + + +class VectorizedParticleSDSPipeline(torch.nn.Module): + + def __init__(self, model_cfg: DictConfig, diffuser_cfg: DictConfig, guidance_cfg: DictConfig, device: torch.device): + super().__init__() + self.device = device + assert guidance_cfg.n_particle >= guidance_cfg.vsd_n_particle + assert guidance_cfg.n_particle >= guidance_cfg.phi_n_particle + + pipe_kwargs = { + "device": self.device, + "torch_dtype": torch.float32, + "local_files_only": not diffuser_cfg.download, + "force_download": diffuser_cfg.force_download, + "resume_download": diffuser_cfg.resume_download, + "ldm_speed_up": model_cfg.ldm_speed_up, + "enable_xformers": model_cfg.enable_xformers, + "gradient_checkpoint": model_cfg.gradient_checkpoint, + "cpu_offload": model_cfg.cpu_offload, + "vae_slicing": False + } + + # load pretrained model + self.sd_pipeline = init_StableDiffusion_pipeline( + model_cfg.model_id, + custom_pipeline=StableDiffusionPipeline, + custom_scheduler=DDIMScheduler, + **pipe_kwargs + ) + # disable grads + self.sd_pipeline.vae.requires_grad_(False) + self.sd_pipeline.text_encoder.requires_grad_(False) + self.sd_pipeline.unet.requires_grad_(False) + # set components + self.vae = self.sd_pipeline.vae + self.unet = self.sd_pipeline.unet + self.scheduler = self.sd_pipeline.scheduler + self.tokenizer = self.sd_pipeline.tokenizer + self.text_encoder = self.sd_pipeline.text_encoder + + if guidance_cfg.phi_model == 'lora': + if guidance_cfg.phi_single: # default, use the single unet + # load LoRA model from the pretrained model + unet_ = self.unet + else: + # create a new unet model + pipe_kwargs.pop('cpu_offload') + pipe_kwargs.pop('vae_slicing') + unet_ = init_diffusers_unet(model_cfg.model_id, **pipe_kwargs) + + # set correct LoRA layers + self.unet_phi, phi_model_layers = self.set_lora_layers(unet_) + self.phi_params = list(phi_model_layers.parameters()) + self.lora_cross_attention_kwargs = {"scale": guidance_cfg.lora_attn_scale} \ + if guidance_cfg.use_attn_scale else {} + self.vae_phi = self.vae + self.vae_phi.requires_grad_(False) + + elif guidance_cfg.phi_model == 'unet_simple': + self.unet_phi = UNet2DConditionModel( + sample_size=64, + in_channels=4, + out_channels=4, + layers_per_block=1, + block_out_channels=(128, 256, 384, 512), + down_block_types=( + "DownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ), + up_block_types=( + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + ), + cross_attention_dim=self.unet.config.cross_attention_dim + ).to(device) + self.phi_params = list(self.unet_phi.parameters()) + self.vae_phi = self.vae + # reset lora + guidance_cfg.use_attn_scale = False + guidance_cfg.lora_attn_scale = False + + # hyper-params + self.phi_single = guidance_cfg.phi_single + self.guidance_scale: float = guidance_cfg.guidance_scale + 
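# A hedged sketch of a guidance_cfg that satisfies what this constructor reads
# (an OmegaConf DictConfig); the numeric values are illustrative defaults only.
from omegaconf import OmegaConf

guidance_cfg = OmegaConf.create({
    "n_particle": 6,            # total number of SVG particles
    "vsd_n_particle": 4,        # particles scored by the pretrained UNet per step
    "phi_n_particle": 2,        # particles used to train the phi model
    "particle_aug": False,      # random perspective/crop augmentation on particles
    "phi_model": "lora",        # or "unet_simple"
    "phi_single": True,
    "use_attn_scale": False,
    "lora_attn_scale": 1.0,
    "guidance_scale": 7.5,
    "phi_guidance_scale": 1.0,
    "grad_clip_val": None,
    "t_schedule": "max_0.5_1500",
    "t_range": [0.02, 0.98],
})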
self.guidance_scale_lora: float = guidance_cfg.phi_guidance_scale + self.grad_clip_val: Union[float, None] = guidance_cfg.grad_clip_val + self.vsd_n_particle: int = guidance_cfg.vsd_n_particle + self.phi_n_particle: int = guidance_cfg.phi_n_particle + self.t_schedule: str = guidance_cfg.t_schedule + self.t_range = list(guidance_cfg.t_range) + print( + f'n_particles: {guidance_cfg.n_particle}, ' + f'enhance_particles: {guidance_cfg.particle_aug}, ' + f'n_particles of score: {self.vsd_n_particle}, ' + f'n_particles of phi_model: {self.phi_n_particle}, \n' + f't_range: {self.t_range}, ' + f't_schedule: {self.t_schedule}, \n' + f'guidance_scale: {self.guidance_scale}, phi_guidance_scale: {self.guidance_scale_lora}.' + ) + print(f"phi_model: {guidance_cfg.phi_model}, " + f"use lora_cross_attn: {guidance_cfg.use_attn_scale}, " + f"lora_attn_scale: {guidance_cfg.lora_attn_scale}. \n") + + # for convenience + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.alphas = self.scheduler.alphas_cumprod.to(self.device) + self.text_embeddings = None + self.text_embedd_cond, self.text_embedd_uncond = None, None + self.text_embeddings_phi = None + self.t = None + + def set_lora_layers(self, unet): # set correct lora layers + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") \ + else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim + ).to(self.device) + unet.set_attn_processor(lora_attn_procs) + lora_layers = AttnProcsLayers(unet.attn_processors) + + unet.requires_grad_(False) + for param in lora_layers.parameters(): + param.requires_grad_(True) + return unet, lora_layers + + @torch.no_grad() + def encode_prompt(self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None): + # text conditional embed + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + prompt_embeds = self.text_encoder(text_inputs.input_ids.to(device))[0] + + if do_classifier_free_guidance: + if negative_prompt is None: + uncond_tokens = [""] + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + else: + uncond_tokens = negative_prompt + + # unconditional embed + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=prompt_embeds.shape[1], + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device))[0] + + concat_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + return concat_prompt_embeds, negative_prompt_embeds, prompt_embeds + + return prompt_embeds, None, None + + def sampling(self, + vae, + unet, + scheduler, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: 
Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0): + + # 0. Default height and width to unet + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + height = height or unet.config.sample_size * vae_scale_factor + width = width or unet.config.sample_size * vae_scale_factor + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = 1 + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, _, _ = self.encode_prompt( + prompt, + self.device, + do_classifier_free_guidance, + negative_prompt, + ) + + # 4. Prepare timesteps + scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = unet.config.in_channels + latents = self.sd_pipeline.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + self.device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.sd_pipeline.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.sd_pipeline.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # update progress_bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.sd_pipeline.run_safety_checker(image, self.device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.sd_pipeline.image_processor.postprocess(image, output_type=output_type, + do_denormalize=do_denormalize) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def sample(self, + prompt, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil"): + return self.sampling(self.vae, self.unet, self.scheduler, + prompt=prompt, + height=height, width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + generator=generator, + output_type=output_type) + + def sample_lora(self, + prompt, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil"): + return self.sampling(self.vae_phi, self.unet_phi, self.scheduler, + prompt=prompt, + height=height, width=width, + num_inference_steps=num_inference_steps, + guidance_scale=self.guidance_scale_lora, + generator=generator, + cross_attention_kwargs=self.lora_cross_attention_kwargs, + output_type=output_type) + + def encode2latent(self, images): + images = (2 * images - 1).clamp(-1.0, 1.0) # images: [B, 3, H, W] + # encode images + latents = self.vae.encode(images).latent_dist.sample() + latents = self.vae.config.scaling_factor * latents + return latents + + def get_noise_map(self, noise_pred, guidance_scale=7.5, use_cfg=True): + if use_cfg: + noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_map = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) + return noise_map + else: + return noise_pred + + def train_phi_model(self, + pred_rgb: torch.Tensor, + new_timesteps: bool = False, + as_latent: bool = False): + # interp to 512x512 to be fed into vae. + if as_latent: + latents = pred_rgb + else: + pred_rgb_ = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! 
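# Sketch of the latent round trip that encode2latent() above relies on: images in
# [0, 1] are mapped to [-1, 1], encoded by the SD VAE, then scaled by
# vae.config.scaling_factor; decoding reverses both steps. The `pipeline` instance
# in the commented lines is a hypothetical VectorizedParticleSDSPipeline.
import torch

def decode_from_latent(vae, latents):
    # inverse of encode2latent(): unscale, decode, map [-1, 1] back to [0, 1]
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    return (image / 2 + 0.5).clamp(0, 1)

# x = torch.rand(1, 3, 512, 512, device=pipeline.device)  # batch of rendered RGB
# z = pipeline.encode2latent(x)                           # -> [1, 4, 64, 64]
# x_rec = decode_from_latent(pipeline.vae, z)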
+ latents = self.encode2latent(pred_rgb_) + + # get phi particles + indices = torch.randperm(latents.size(0)) + latents_phi = latents[indices[:self.phi_n_particle]] + latents_phi = latents_phi.detach() + + # get timestep + if new_timesteps: + t = torch.randint(0, self.num_train_timesteps, (1,), device=self.device) + else: + t = self.t + + noise = torch.randn_like(latents_phi) + noisy_latents = self.scheduler.add_noise(latents_phi, noise, t) + + if self.scheduler.config.prediction_type == "epsilon": + target = noise + elif self.scheduler.config.prediction_type == "v_prediction": + target = self.scheduler.get_velocity(latents_phi, noise, t) + else: + raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}") + + # predict the noise residual and compute loss + noise_pred = self.unet_phi( + noisy_latents, t, + encoder_hidden_states=self.text_embeddings_phi, + cross_attention_kwargs=self.lora_cross_attention_kwargs, + ).sample + + return F.mse_loss(noise_pred, target, reduction="mean") + + def train_phi_model_refl(self, + pred_rgb: torch.Tensor, + weight: float = 1, + new_timesteps: bool = True): + # interp to 512x512 to be fed into vae. + pred_rgb_ = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.encode2latent(pred_rgb_) + + # get phi particles + indices = torch.randperm(latents.size(0)) + latents_phi = latents[indices[:self.phi_n_particle]] + latents_phi = latents_phi.detach() + + # get timestep + if new_timesteps: + t = torch.randint(0, self.num_train_timesteps, (1,), device=self.device) + else: + t = self.t + + noise = torch.randn_like(latents_phi) + noisy_latents = self.scheduler.add_noise(latents_phi, noise, t) + + if self.scheduler.config.prediction_type == "epsilon": + target = noise + elif self.scheduler.config.prediction_type == "v_prediction": + target = self.scheduler.get_velocity(latents_phi, noise, t) + else: + raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}") + + # predict the noise residual and compute loss + noise_pred = self.unet_phi( + noisy_latents, t, + encoder_hidden_states=self.text_embedd_cond, + cross_attention_kwargs=self.lora_cross_attention_kwargs, + ).sample + + rewards = torch.tensor(weight, dtype=torch.float32, device=self.device) + return rewards * F.mse_loss(noise_pred, target, reduction="mean") + + def schedule_timestep(self, step): + min_step = int(self.num_train_timesteps * self.t_range[0]) + max_step = int(self.num_train_timesteps * self.t_range[1]) + if self.t_schedule == 'randint': + t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device) + elif re.match(r"max_([\d.]+)_(\d+)", self.t_schedule): + # Anneal time schedule + # e.g: t_schedule == 'max_0.5_200' + # [0.02, 0.98] -> [0.02, 0.5] after 200 steps + tag, t_val, step_upd = str(self.t_schedule).split('_') + t_val, step_upd = float(t_val), int(step_upd) + if step >= step_upd: + max_step = int(self.num_train_timesteps * t_val) + t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device) + elif re.match(r"min_([\d.]+)_(\d+)", self.t_schedule): + # Anneal time schedule + # e.g: t_schedule == 'min_0.5_200' + # [0.02, 0.98] -> [0.5, 0.98] after 200 steps + tag, t_val, step_upd = str(self.t_schedule).split('_') + t_val, step_upd = float(t_val), int(step_upd) + if step >= step_upd: + min_step = int(self.num_train_timesteps * t_val) + t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, 
device=self.device) + else: + raise NotImplementedError(f"{self.t_schedule} is not support.") + return t + + def set_text_embeddings(self, prompt, negative_prompt, do_classifier_free_guidance): + if self.text_embeddings is not None: + return + + # encode text prompt + text_embeddings, text_embeddings_uncond, text_embeddings_cond = \ + self.encode_prompt(prompt, self.device, do_classifier_free_guidance, negative_prompt=negative_prompt) + + # set pretrained model text embedding + text_embeddings_uncond, text_embeddings_cond = text_embeddings.chunk(2) + self.text_embedd_uncond, self.text_embedd_cond = text_embeddings_uncond, text_embeddings_cond + text_embeddings_unconds = text_embeddings_uncond.repeat_interleave(self.vsd_n_particle, dim=0) + text_embeddings_conds = text_embeddings_cond.repeat_interleave(self.vsd_n_particle, dim=0) + text_embeddings = torch.cat([text_embeddings_unconds, text_embeddings_conds]) + self.text_embeddings = text_embeddings + + # set phi model text embedding + self.text_embeddings_phi = text_embeddings_cond.repeat_interleave(self.phi_n_particle, dim=0) + + def x_augment(self, x: torch.Tensor, img_size: int = 512): + augment_compose = transforms.Compose([ + transforms.RandomPerspective(distortion_scale=0.5, p=0.7), + transforms.RandomCrop(size=(img_size, img_size), pad_if_needed=True, padding_mode='reflect') + ]) + return augment_compose(x) + + def variational_score_distillation(self, + pred_rgb: torch.Tensor, + step: int, + prompt: Union[List, str], + negative_prompt: Union[List, str] = None, + grad_scale: float = 1.0, + enhance_particle: bool = False, + im_size: int = 512, + as_latent: bool = False): + bz = pred_rgb.shape[0] + + # data enhancement for the input particles + pred_rgb = self.x_augment(pred_rgb, im_size) if enhance_particle else pred_rgb + + # interp to 512x512 to be fed into vae. + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + else: + pred_rgb_ = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! 
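# A standalone sketch of the 'max_<frac>_<step>' schedule parsed in
# schedule_timestep() above: before `step_upd` optimization steps, t is drawn from
# [t_range[0], t_range[1]] * T; afterwards the upper bound anneals down to
# <frac> * T. The numbers below are illustrative.
import torch

def sample_t(step, T=1000, t_range=(0.02, 0.98), frac=0.5, step_upd=200):
    lo = int(T * t_range[0])
    hi = int(T * (frac if step >= step_upd else t_range[1]))
    return torch.randint(lo, hi + 1, (1,), dtype=torch.long)

# sample_t(0)    -> t uniform in [20, 980]
# sample_t(500)  -> t uniform in [20, 500]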
+ # latents = self.encode2latent(pred_rgb_) + latent_list = [self.encode2latent(pred_rgb_[i].unsqueeze(0)) for i in range(bz)] + latents = torch.cat(latent_list, dim=0) + latents = latents.to(self.device) + + # random sample n_particle_vsd particles from latents + latents_vsd = latents[torch.randperm(bz)[:self.vsd_n_particle]] + + # encode input prompt + do_classifier_free_guidance = True + self.set_text_embeddings(prompt, negative_prompt, do_classifier_free_guidance) + text_embeddings = self.text_embeddings + + # timestep a.k.a noise level + self.t = self.schedule_timestep(step) + + # predict the noise residual with unet, stop gradient + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents_vsd) + latents_noisy = self.scheduler.add_noise(latents_vsd, noise, self.t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) if do_classifier_free_guidance else latents_noisy + # pretrained noise prediction network + noise_pred_pretrain = self.unet( + latent_model_input, self.t, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs={'scale': 0.0} if self.phi_single else {} + ).sample + + # use conditional text embeddings in phi_model + _, text_embeddings_cond = text_embeddings.chunk(2) + # estimated noise prediction network + noise_pred_est = self.unet_phi( + latents_noisy, self.t, + encoder_hidden_states=text_embeddings_cond, + cross_attention_kwargs=self.lora_cross_attention_kwargs + ).sample + + # get pretrained score + noise_pred_pretrain = self.get_noise_map(noise_pred_pretrain, self.guidance_scale, use_cfg=True) + # get estimated score + noise_pred_est = self.get_noise_map(noise_pred_est, self.guidance_scale_lora, use_cfg=False) + + # w(t), sigma_t^2 + w = (1 - self.alphas[self.t]) + grad = grad_scale * w * (noise_pred_pretrain - noise_pred_est.detach()) + grad = torch.nan_to_num(grad) + + # grad clipping for stable training + if self.grad_clip_val is not None and self.grad_clip_val > 0: + grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val) + + # re-parameterization trick: + # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad + target = (latents_vsd - grad).detach() + loss_vpsd = 0.5 * F.mse_loss(latents_vsd, target, reduction="sum") + + return loss_vpsd, grad.norm(), latents, self.t diff --git a/svgdreamer/painter/__init__.py b/svgdreamer/painter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56dada093c07f69ee8c55e634196ed84d0f8cbad --- /dev/null +++ b/svgdreamer/painter/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Description: + +from .painter_params import ( + Painter, PainterOptimizer, CosineWithWarmupLRLambda, RandomCoordInit, NaiveCoordInit, SparseCoordInit, get_sdf) +from .component_painter_params import CompPainter, CompPainterOptimizer +from .loss import xing_loss_fn +from .VPSD_pipeline import VectorizedParticleSDSPipeline +from .diffusion_pipeline import DiffusionPipeline diff --git a/svgdreamer/painter/component_painter_params.py b/svgdreamer/painter/component_painter_params.py new file mode 100644 index 0000000000000000000000000000000000000000..93f261a670f23808c0b43d10f93ed473a2918606 --- /dev/null +++ b/svgdreamer/painter/component_painter_params.py @@ -0,0 +1,610 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: content painter and optimizer +# Copyright (c) 2023, XiMing Xing. 
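# A self-contained check of the re-parameterization trick used in
# variational_score_distillation() above: with target = (z - grad).detach(), the
# gradient of 0.5 * ||z - target||^2 with respect to z is exactly `grad`, so
# backprop moves the latents along the chosen score direction.
import torch
import torch.nn.functional as F

z = torch.randn(1, 4, 64, 64, requires_grad=True)
grad = torch.randn_like(z)        # stands in for w(t) * (noise_pretrain - noise_phi)
target = (z - grad).detach()
loss = 0.5 * F.mse_loss(z, target, reduction="sum")
loss.backward()
assert torch.allclose(z.grad, grad)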
+# License: MIT License + +import copy +import math +import random +import pathlib +from typing import Dict, Tuple + +from shapely.geometry.polygon import Polygon +from omegaconf import DictConfig +import numpy as np +import pydiffvg +import torch +from torch.optim.lr_scheduler import LambdaLR + +from svgdreamer.painter import (SparseCoordInit, RandomCoordInit, NaiveCoordInit, get_sdf) +from svgdreamer.libs import get_optimizer + + +class CompPainter: + + def __init__( + self, + style: str, + target_img: torch.Tensor, + canvas_size: Tuple[int, int] = (600, 600), + num_segments: int = 4, + segment_init: str = 'circle', + radius: int = 20, + n_grid: int = 32, + stroke_width: int = 3, + device=None, + attn_init: bool = False, + attention_map: torch.Tensor = None, + attn_prob_tau: float = None, + ): + self.style = style + self.device = device + self.target_img = target_img + + # curve params + self.num_segments = num_segments + self.segment_init = segment_init + self.radius = radius + + self.canvas_width, self.canvas_height = canvas_size + """pixelart params""" + self.n_grid = n_grid # divide the canvas into n grids + self.pixel_per_grid = self.canvas_width // self.n_grid + """sketch params""" + self.stroke_width = stroke_width + """iconography params""" + self.color_ref = None + + self.shapes = [] # record all paths + self.shape_groups = [] + self.cur_shapes, self.cur_shape_groups = [], [] # record the current optimized path + self.point_vars = [] + self.color_vars = [] + self.width_vars = [] + + # init + self.attention_map = attention_map + self.attn_init = attn_init + self.attn_prob_tau = attn_prob_tau + self.select_inds = None + self.pos_init_method = None + + # background + self.para_bg = torch.tensor([1., 1., 1.], requires_grad=False, device=self.device) + # count the number of strokes + self.strokes_counter = 0 # counts the number of calls to "get_path" + + def attn_init_points(self, num_paths, mask=None): + attn_map = (self.attention_map - self.attention_map.min()) / \ + (self.attention_map.max() - self.attention_map.min()) + + attn_map_soft = np.copy(attn_map) + attn_map_soft[attn_map > 0] = softmax_t(attn_map[attn_map > 0], tau=self.attn_prob_tau) + # for visualizing + attn_thresh = np.copy(attn_map_soft) + # the probabilities associated with each entry in attn_map + attn_map_soft /= np.sum(attn_map_soft) + # select points + k = num_paths + + # select k points randomly + positions = np.where(mask == 1) + positions = np.stack(positions, axis=1) + np.random.shuffle(positions) + positions = positions[:k] + + # note: only use to visual + visual_coords = np.copy(positions) + + canvas_coords = np.copy(positions) + canvas_coords[:, [0, 1]] = canvas_coords[:, [1, 0]] + self.select_inds = canvas_coords + + # for visualizing + return attn_thresh, visual_coords + + def component_wise_path_init(self, pred, init_type: str = 'sparse'): + if init_type == 'random': + self.pos_init_method = RandomCoordInit(self.canvas_height, self.canvas_width) + + elif init_type == 'sparse': + assert self.target_img is not None # target_img as GT + # when initialized for the first time, the render result is None + if pred is None: + pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width) + # then pred is the render result + self.pos_init_method = SparseCoordInit(pred, self.target_img) + + elif init_type == 'naive': + assert self.target_img is not None # target_img as GT + if pred is None: + pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width) + 
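# A small numeric sketch of the temperature softmax (softmax_t, defined later in
# this file) that attn_init_points() applies to the attention map: a smaller tau
# sharpens the distribution, so strokes are seeded at the strongest responses.
import numpy as np

def softmax_t(x, tau=0.2):
    e_x = np.exp(x / tau)
    return e_x / e_x.sum()

scores = np.array([0.1, 0.5, 0.9])
print(np.round(softmax_t(scores, tau=1.0), 3))  # ~[0.212, 0.316, 0.472], nearly flat
print(np.round(softmax_t(scores, tau=0.1), 3))  # ~[0.000, 0.018, 0.982], almost one-hot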
self.pos_init_method = NaiveCoordInit(pred, self.target_img) + + else: + raise NotImplementedError(f"'{init_type}' is not support.") + + def init_image(self, num_paths=0): + self.cur_shapes, self.cur_shape_groups = [], [] + + if self.style == 'pixelart': # update path definition + num_paths = self.n_grid + + for i in range(num_paths): + if self.style == 'iconography': + path = self.get_path() + self.shapes.append(path) + self.cur_shapes.append(path) + + wref, href = self.color_ref + wref = max(0, min(int(wref), self.canvas_width - 1)) + href = max(0, min(int(href), self.canvas_height - 1)) + fill_color_init = list(self.target_img[0, :, href, wref]) + [1.] + fill_color_init = torch.FloatTensor(fill_color_init) + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([len(self.shapes) - 1]), + fill_color=fill_color_init, + stroke_color=None + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + + elif self.style == 'pixelart': + fill_color_init = torch.FloatTensor(np.random.uniform(size=[4])) + fill_color_init[-1] = 1.0 + + for j in range(num_paths): + path = self.get_path(coord=[i, j]) + self.shapes.append(path) + self.cur_shapes.append(path) + + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.LongTensor([i * num_paths + j]), + fill_color=fill_color_init, + stroke_color=None, + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + + elif self.style == 'sketch': + path = self.get_path() + self.shapes.append(path) + self.cur_shapes.append(path) + + stroke_color_init = torch.tensor([0.0, 0.0, 0.0, 1.0]) + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([len(self.shapes) - 1]), + fill_color=None, + stroke_color=stroke_color_init + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + + elif self.style == 'painting': + path = self.get_path() + self.shapes.append(path) + self.cur_shapes.append(path) + + wref, href = self.color_ref + wref = max(0, min(int(wref), self.canvas_width - 1)) + href = max(0, min(int(href), self.canvas_height - 1)) + stroke_color_init = list(self.target_img[0, :, href, wref]) + [1.] 
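# Sketch of the color-initialization pattern used in init_image() above: a new
# path's fill/stroke color is seeded from the target-image pixel at the path's
# init coordinate, with alpha fixed to 1. `target_img` is assumed to be a
# [1, 3, H, W] tensor in [0, 1]; the helper name is hypothetical.
import torch

def init_color_from_target(target_img, x, y):
    _, _, h, w = target_img.shape
    x = max(0, min(int(x), w - 1))
    y = max(0, min(int(y), h - 1))
    rgb = target_img[0, :, y, x]                               # (row=y, col=x) indexing
    return torch.cat([rgb, torch.ones(1, device=rgb.device)])  # RGBA with alpha = 1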
+ path_group = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([len(self.shapes) - 1]), + fill_color=None, + stroke_color=stroke_color_init + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + + img = self.render_warp() + img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4]) + img = img.unsqueeze(0) # convert img from HWC to NCHW + img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW + return img + + def get_image(self, step: int = 0): + img = self.render_warp(step) + img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4]) + img = img.unsqueeze(0) # convert img from HWC to NCHW + img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW + return img + + def get_path(self, coord=None): + num_segments = self.num_segments + + points = [] + if self.style == 'iconography': + num_control_points = [2] * num_segments + # init segment + if self.segment_init == 'circle': + radius = self.radius if self.radius is not None else np.random.uniform(0.5, 1) + + if self.attn_init: + center = self.select_inds[self.strokes_counter] # shape: (2,) + else: + center = (random.random(), random.random()) \ + if self.pos_init_method is None else self.pos_init_method() + + bias = center + self.color_ref = copy.deepcopy(bias) + + points = get_circle_coordinates(center, radius, k=num_segments * 3) + points = torch.FloatTensor(points) + else: + if self.attn_init: + p0 = self.select_inds[self.strokes_counter] + else: + p0 = self.pos_init_method() + + self.color_ref = copy.deepcopy(p0) + points.append(p0) + for j in range(num_segments): + radius = self.radius + p1 = (p0[0] + radius * np.random.uniform(-0.5, 0.5), + p0[1] + radius * np.random.uniform(-0.5, 0.5)) + p2 = (p1[0] + radius * np.random.uniform(-0.5, 0.5), + p1[1] + radius * np.random.uniform(-0.5, 0.5)) + p3 = (p2[0] + radius * np.random.uniform(-0.5, 0.5), + p2[1] + radius * np.random.uniform(-0.5, 0.5)) + points.append(p1) + points.append(p2) + if j < num_segments - 1: + points.append(p3) + p0 = p3 + points = torch.FloatTensor(points) + + path = pydiffvg.Path(num_control_points=torch.LongTensor(num_control_points), + points=points, + stroke_width=torch.tensor(0.0), + is_closed=True) + elif self.style in ['sketch', 'painting', 'ink']: + num_control_points = torch.zeros(num_segments, dtype=torch.long) + 2 + points = [] + + if self.attn_init: + p0 = self.select_inds[self.strokes_counter] + else: + p0 = (random.random(), random.random()) \ + if self.pos_init_method is None else self.pos_init_method() + + self.color_ref = copy.deepcopy(p0) + + points.append(p0) + for j in range(num_segments): + radius = 0.1 + p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5)) + p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5)) + p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5)) + points.append(p1) + points.append(p2) + points.append(p3) + p0 = p3 + points = torch.tensor(points).to(self.device) + + if not self.attn_init: + points[:, 0] *= self.canvas_width + points[:, 1] *= self.canvas_height + + path = pydiffvg.Path(num_control_points=torch.LongTensor(num_control_points), + points=points, + stroke_width=torch.tensor(self.stroke_width), + is_closed=False) + elif self.style == 'pixelart': + x = coord[0] * self.pixel_per_grid + y = coord[1] * self.pixel_per_grid + points = torch.FloatTensor([ + [x, y], + [x + self.pixel_per_grid, y], + [x + self.pixel_per_grid, y + 
self.pixel_per_grid], + [x, y + self.pixel_per_grid] + ]).to(self.device) + path = pydiffvg.Polygon(points=points, + stroke_width=torch.tensor(0.0), + is_closed=True) + + self.strokes_counter += 1 + return path + + def clip_curve_shape(self): + for group in self.shape_groups: + group.fill_color.data.clamp_(0.0, 1.0) + + def reinitialize_paths(self, + reinit_path: bool = False, + opacity_threshold: float = None, + area_threshold: float = None, + fpath: pathlib.Path = None): + """ + reinitialize paths, also known as 'Reinitializing paths' in VectorFusion paper. + + Args: + reinit_path: whether to reinitialize paths or not. + opacity_threshold: Threshold of opacity. + area_threshold: Threshold of the closed polygon area. + fpath: The path to save the reinitialized SVG. + """ + if self.style == 'iconography' and reinit_path: + # re-init by opacity_threshold + select_path_ids_by_opc = [] + if opacity_threshold != 0 and opacity_threshold is not None: + def get_keys_below_threshold(my_dict, threshold): + keys_below_threshold = [key for key, value in my_dict.items() if value < threshold] + return keys_below_threshold + + opacity_record_ = {group.shape_ids.item(): group.fill_color.data[-1].item() + for group in self.cur_shape_groups} + # print("-> opacity_record: ", opacity_record_) + print("-> opacity_record: ", [f"{k}: {v:.3f}" for k, v in opacity_record_.items()]) + select_path_ids_by_opc = get_keys_below_threshold(opacity_record_, opacity_threshold) + print("select_path_ids_by_opc: ", select_path_ids_by_opc) + + # remove path by area_threshold + select_path_ids_by_area = [] + if area_threshold != 0 and area_threshold is not None: + area_records = [Polygon(shape.points.detach().numpy()).area for shape in self.cur_shapes] + # print("-> area_records: ", area_records) + print("-> area_records: ", ['%.2f' % i for i in area_records]) + for i, shape in enumerate(self.cur_shapes): + if Polygon(shape.points.detach().numpy()).area < area_threshold: + select_path_ids_by_area.append(shape.id) + print("select_path_ids_by_area: ", select_path_ids_by_area) + + # re-init paths + reinit_union = list(set(select_path_ids_by_opc + select_path_ids_by_area)) + if len(reinit_union) > 0: + for i, path in enumerate(self.cur_shapes): + if path.id in reinit_union: + self.cur_shapes[i] = self.get_path() + for i, group in enumerate(self.cur_shape_groups): + shp_ids = group.shape_ids.cpu().numpy().tolist() + if set(shp_ids).issubset(reinit_union): + fill_color_init = torch.FloatTensor(np.random.uniform(size=[4])) + fill_color_init[-1] = np.random.uniform(0.7, 1) + stroke_color_init = torch.FloatTensor(np.random.uniform(size=[4])) + self.cur_shape_groups[i] = pydiffvg.ShapeGroup( + shape_ids=torch.tensor(list(shp_ids)), + fill_color=fill_color_init, + stroke_color=stroke_color_init) + # save reinit svg + self.save_svg(fpath) + + print("-" * 40) + + def render_warp(self, seed=0): + scene_args = pydiffvg.RenderFunction.serialize_scene( + self.canvas_width, self.canvas_height, self.shapes, self.shape_groups + ) + _render = pydiffvg.RenderFunction.apply + img = _render(self.canvas_width, # width + self.canvas_height, # height + 2, # num_samples_x + 2, # num_samples_y + seed, # seed + None, + *scene_args) + return img + + def calc_distance_weight(self, loss_weight_keep): + shapes_forsdf = copy.deepcopy(self.cur_shapes) + shape_groups_forsdf = copy.deepcopy(self.cur_shape_groups) + for si in shapes_forsdf: + si.stroke_width = torch.FloatTensor([0]).to(self.device) + for sg_idx, sgi in enumerate(shape_groups_forsdf): + 
sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(self.device) + sgi.shape_ids = torch.LongTensor([sg_idx]).to(self.device) + + sargs_forsdf = pydiffvg.RenderFunction.serialize_scene( + self.canvas_width, self.canvas_height, shapes_forsdf, shape_groups_forsdf + ) + _render = pydiffvg.RenderFunction.apply + with torch.no_grad(): + im_forsdf = _render(self.canvas_width, # width + self.canvas_height, # height + 2, # num_samples_x + 2, # num_samples_y + 0, # seed + None, + *sargs_forsdf) + + # use alpha channel is a trick to get 0-1 image + im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy() + loss_weight = get_sdf(im_forsdf, normalize='to1') + loss_weight += loss_weight_keep + loss_weight = np.clip(loss_weight, 0, 1) + loss_weight = torch.FloatTensor(loss_weight).to(self.device) + return loss_weight + + def set_points_parameters(self, id_delta=0): + self.point_vars = [] + for i, path in enumerate(self.cur_shapes): + path.id = i + id_delta # set point id + path.points.requires_grad = True + self.point_vars.append(path.points) + + def get_point_params(self): + return self.point_vars + + def set_color_parameters(self): + self.color_vars = [] + for i, group in enumerate(self.cur_shape_groups): + if group.fill_color is not None: + group.fill_color.requires_grad = True + self.color_vars.append(group.fill_color) + if group.stroke_color is not None: + group.stroke_color.requires_grad = True + self.color_vars.append(group.stroke_color) + + def get_color_params(self): + return self.color_vars + + def set_width_parameters(self): + # stroke`s width optimization + self.width_vars = [] + for i, path in enumerate(self.shapes): + path.stroke_width.requires_grad = True + self.width_vars.append(path.stroke_width) + + def get_width_params(self): + return self.width_vars + + def save_svg(self, fpath): + pydiffvg.save_svg(f'{fpath}', + self.canvas_width, + self.canvas_height, + self.shapes, + self.shape_groups) + + def load_svg(self, path_svg): + canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg) + return canvas_width, canvas_height, shapes, shape_groups + + +def softmax_t(x, tau=0.2): + e_x = np.exp(x / tau) + return e_x / e_x.sum() + + +def get_circle_coordinates(center, radius, k): + coordinates = [] + cx, cy = center + angle = 2 * math.pi / k + + for i in range(k): + theta = i * angle # cur angle + x = cx + radius * math.cos(theta) # x + y = cy + radius * math.sin(theta) # y + coordinates.append((x, y)) + + return coordinates + + +class LinearDecayLRLambda: + + def __init__(self, init_lr, keep_ratio, decay_every, decay_ratio): + self.init_lr = init_lr + self.keep_ratio = keep_ratio + self.decay_every = decay_every + self.decay_ratio = decay_ratio + + def __call__(self, n): + if n < self.keep_ratio * self.decay_every: + return self.init_lr + + decay_time = n // self.decay_every + decay_step = n % self.decay_every + lr_s = self.decay_ratio ** decay_time + lr_e = self.decay_ratio ** (decay_time + 1) + r = decay_step / self.decay_every + lr = lr_s * (1 - r) + lr_e * r + return lr + + +class CompPainterOptimizer: + + def __init__(self, + renderer: CompPainter, + style: str, + num_iter: int, + lr_config: DictConfig, + optim_bg: bool = False): + self.renderer = renderer + self.style = style + self.num_iter = num_iter + self.lr_config = lr_config + schedule_cfg = self.lr_config.schedule + self.optim_bg = optim_bg + + if style == 'iconography': + self.optim_point, self.optim_color, self.optim_width = True, True, False + self.point_lr_lambda = 
LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio, + self.num_iter, schedule_cfg.decay_ratio) + if style == 'pixelart': + self.optim_point, self.optim_color, self.optim_width = False, True, False + self.point_lr_lambda = None + if style == 'sketch': + self.optim_point, self.optim_color, self.optim_width = True, False, False + self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio, + self.num_iter, schedule_cfg.decay_ratio) + if style == 'ink': + self.optim_point, self.optim_color, self.optim_width = True, False, True + self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio, + self.num_iter, schedule_cfg.decay_ratio) + if style == 'painting': + self.optim_point, self.optim_color, self.optim_width = True, True, True + self.point_lr_lambda = LinearDecayLRLambda(self.lr_config.point, schedule_cfg.keep_ratio, + self.num_iter, schedule_cfg.decay_ratio) + + self.point_optimizer = None + self.color_optimizer = None + self.width_optimizer = None + self.bg_optimizer = None + + self.point_scheduler = None + + def init_optimizers(self, pid_delta=0): + optim_cfg = self.lr_config.optim + optim_name = optim_cfg.name + + params = {} + if self.optim_point: + self.renderer.set_points_parameters(pid_delta) + params['point'] = self.renderer.get_point_params() + + if len(params['point']) > 0: + self.point_optimizer = get_optimizer(optim_name, params['point'], self.lr_config.point, optim_cfg) + if self.point_lr_lambda is not None: + self.point_scheduler = LambdaLR(self.point_optimizer, lr_lambda=self.point_lr_lambda, last_epoch=-1) + + if self.optim_color: + self.renderer.set_color_parameters() + params['color'] = self.renderer.get_color_params() + if len(params['color']) > 0: + self.color_optimizer = get_optimizer(optim_name, params['color'], self.lr_config.color, optim_cfg) + + if self.optim_width: + self.renderer.set_width_parameters() + params['width'] = self.renderer.get_width_params() + if len(params['width']) > 0: + self.width_optimizer = get_optimizer(optim_name, params['width'], self.lr_config.width, optim_cfg) + + if self.optim_bg: + self.renderer.para_bg.requires_grad = True + self.bg_optimizer = get_optimizer(optim_name, self.renderer.para_bg, self.lr_config.bg, optim_cfg) + + def update_lr(self): + if self.point_scheduler is not None: + self.point_scheduler.step() + + def zero_grad_(self): + if self.point_optimizer is not None: + self.point_optimizer.zero_grad() + if self.color_optimizer is not None: + self.color_optimizer.zero_grad() + if self.width_optimizer is not None: + self.width_optimizer.zero_grad() + if self.bg_optimizer is not None: + self.bg_optimizer.zero_grad() + + def step_(self): + if self.point_optimizer is not None: + self.point_optimizer.step() + if self.color_optimizer is not None: + self.color_optimizer.step() + if self.width_optimizer is not None: + self.width_optimizer.step() + if self.bg_optimizer is not None: + self.bg_optimizer.step() + + def get_lr(self) -> Dict: + lr = {} + if self.point_optimizer is not None: + lr['pnt'] = self.point_optimizer.param_groups[0]['lr'] + if self.color_optimizer is not None: + lr['clr'] = self.color_optimizer.param_groups[0]['lr'] + if self.width_optimizer is not None: + lr['wd'] = self.width_optimizer.param_groups[0]['lr'] + if self.bg_optimizer is not None: + lr['bg'] = self.bg_optimizer.param_groups[0]['lr'] + return lr diff --git a/svgdreamer/painter/diffusion_pipeline.py b/svgdreamer/painter/diffusion_pipeline.py new file mode 100644 index 
0000000000000000000000000000000000000000..50febe91253da976682f2a12ef20acb61fa27304 --- /dev/null +++ b/svgdreamer/painter/diffusion_pipeline.py @@ -0,0 +1,402 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: +import PIL +from PIL import Image +from typing import Any, List, Optional, Union, Dict +from omegaconf import DictConfig + +import numpy as np +import torch +from diffusers import StableDiffusionPipeline +from diffusers import DDIMScheduler +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( + rescale_noise_cfg, StableDiffusionPipelineOutput) + +from svgdreamer.diffusers_warp import init_StableDiffusion_pipeline +from svgdreamer.token2attn.attn_control import AttentionStore +from svgdreamer.token2attn.ptp_utils import text_under_image, view_images + + +class DiffusionPipeline(torch.nn.Module): + + def __init__(self, model_cfg: DictConfig, diffuser_cfg: DictConfig, device: torch.device): + super().__init__() + self.device = device + + pipe_kwargs = { + "device": self.device, + "torch_dtype": torch.float32, + "local_files_only": not diffuser_cfg.download, + "force_download": diffuser_cfg.force_download, + "resume_download": diffuser_cfg.resume_download, + "ldm_speed_up": model_cfg.ldm_speed_up, + "enable_xformers": model_cfg.enable_xformers, + "gradient_checkpoint": model_cfg.gradient_checkpoint, + "cpu_offload": model_cfg.cpu_offload, + "vae_slicing": False + } + + # load pretrained model + self.sd_pipeline = init_StableDiffusion_pipeline( + model_cfg.model_id, + custom_pipeline=StableDiffusionPipeline, + custom_scheduler=DDIMScheduler, + **pipe_kwargs + ) + # disable grads + self.sd_pipeline.vae.requires_grad_(False) + self.sd_pipeline.text_encoder.requires_grad_(False) + self.sd_pipeline.unet.requires_grad_(False) + # set components + self.vae = self.sd_pipeline.vae + self.unet = self.sd_pipeline.unet + self.scheduler = self.sd_pipeline.scheduler + self.tokenizer = self.sd_pipeline.tokenizer + self.text_encoder = self.sd_pipeline.text_encoder + + @torch.no_grad() + def encode_prompt(self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None): + # text conditional embed + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + prompt_embeds = self.text_encoder(text_inputs.input_ids.to(device))[0] + + if do_classifier_free_guidance: + if negative_prompt is None: + uncond_tokens = [""] + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + else: + uncond_tokens = negative_prompt + + # unconditional embed + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=prompt_embeds.shape[1], + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device))[0] + + concat_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + return concat_prompt_embeds, negative_prompt_embeds, prompt_embeds + + return prompt_embeds, None, None + + def register_attention_control(self, controller): + attn_procs = {} + cross_att_count = 0 + for name in self.unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = self.unet.config.block_out_channels[-1] + place_in_unet = "mid" + elif name.startswith("up_blocks"): + block_id = 
int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id] + place_in_unet = "up" + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.unet.config.block_out_channels[block_id] + place_in_unet = "down" + else: + continue + cross_att_count += 1 + attn_procs[name] = P2PCrossAttnProcessor( + controller=controller, place_in_unet=place_in_unet + ) + + self.unet.set_attn_processor(attn_procs) + controller.num_att_layers = cross_att_count + + @staticmethod + def aggregate_attention(prompts, + attention_store: AttentionStore, + res: int, + from_where: List[str], + is_cross: bool, + select: int): + if isinstance(prompts, str): + prompts = [prompts] + assert isinstance(prompts, list) + + out = [] + attention_maps = attention_store.get_average_attention() + num_pixels = res ** 2 + for location in from_where: + for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + if item.shape[1] == num_pixels: + cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] + out.append(cross_maps) + out = torch.cat(out, dim=0) + out = out.sum(0) / out.shape[0] + return out.cpu() + + def get_cross_attention(self, + prompts, + attention_store: AttentionStore, + res: int, + from_where: List[str], + select: int = 0, + save_path=None): + tokens = self.tokenizer.encode(prompts[select]) + decoder = self.tokenizer.decode + # shape: [res ** 2, res ** 2, seq_len] + attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, True, select) + + images_text = [] + images = [] + for i in range(len(tokens)): + image = attention_maps[:, :, i] + image = 255 * image / image.max() + image = image.unsqueeze(-1).expand(*image.shape, 3) + image = image.numpy().astype(np.uint8) + image = np.array(Image.fromarray(image).resize((256, 256))) + images.append(np.copy(image)) + image = text_under_image(image, decoder(int(tokens[i]))) + images_text.append(image) + image_array = np.stack(images_text, axis=0) + view_images(image_array, save_image=True, fp=save_path) + + return attention_maps, tokens + + def get_self_attention_comp(self, + prompts, + attention_store: AttentionStore, + res: int, + from_where: List[str], + img_size: int = 224, + max_com=10, + select: int = 0, + save_path=None): + attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, False, select) + attention_maps = attention_maps.numpy().reshape((res ** 2, res ** 2)) + # shape: [res ** 2, res ** 2] + u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) + print(f"self-attention maps: {attention_maps.shape}, " + f"u: {u.shape}, " + f"s: {s.shape}, " + f"vh: {vh.shape}") + + images = [] + vh_returns = [] + for i in range(max_com): + image = vh[i].reshape(res, res) + image = (image - image.min()) / (image.max() - image.min()) + image = 255 * image + + ret_ = Image.fromarray(image).resize((img_size, img_size), resample=PIL.Image.Resampling.BILINEAR) + vh_returns.append(np.array(ret_)) + + image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) + image = Image.fromarray(image).resize((256, 256)) + image = np.array(image) + images.append(image) + image_array = np.stack(images, axis=0) + view_images(image_array, num_rows=max_com // 10, offset_ratio=0, + save_image=True, fp=save_path / "self-attn-vh.png") + + return attention_maps, (u, s, vh), np.stack(vh_returns, axis=0) + + def sampling(self, + vae, + unet, + scheduler, + prompt: Union[str, 
List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + controller: AttentionStore = None, # feed attention_store as control of ptp + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0): + + # add attention controller + self.register_attention_control(controller) + + # 0. Default height and width to unet + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + height = height or unet.config.sample_size * vae_scale_factor + width = width or unet.config.sample_size * vae_scale_factor + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = 1 + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, _, _ = self.encode_prompt( + prompt, + self.device, + do_classifier_free_guidance, + negative_prompt, + ) + + # 4. Prepare timesteps + scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = unet.config.in_channels + latents = self.sd_pipeline.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + self.device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.sd_pipeline.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.sd_pipeline.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # controller callback + latents = controller.step_callback(latents) + + # update progress_bar + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.sd_pipeline.run_safety_checker(image, self.device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.sd_pipeline.image_processor.postprocess(image, output_type=output_type, + do_denormalize=do_denormalize) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def sample(self, + prompt, + height: Optional[int] = None, + width: Optional[int] = None, + controller: AttentionStore = None, # feed attention_store as control of ptp + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil"): + return self.sampling(self.vae, self.unet, self.scheduler, + prompt=prompt, + height=height, width=width, + controller=controller, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + generator=generator, + output_type=output_type) + + def encode2latent(self, images): + images = (2 * images - 1).clamp(-1.0, 1.0) # images: [B, 3, H, W] + # encode images + latents = self.vae.encode(images).latent_dist.sample() + latents = self.vae.config.scaling_factor * latents + return latents + + +class P2PCrossAttnProcessor: + + def __init__(self, controller, place_in_unet): + super().__init__() + self.controller = controller + self.place_in_unet = place_in_unet + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size) + + query = attn.to_q(hidden_states) + + is_cross = encoder_hidden_states is not None + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + + # one line change + self.controller(attention_probs, is_cross, self.place_in_unet) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states 
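Editor's note (not part of the patch): the `DiffusionPipeline` above registers a prompt-to-prompt style `P2PCrossAttnProcessor` on every UNet attention layer so that an `AttentionStore` controller can record cross-/self-attention while `sample()` runs, and `get_cross_attention()` then aggregates those maps per prompt token. The sketch below shows one plausible way to drive it. Assumptions, not taken from this patch: `AttentionStore()` accepts no required arguments (as in the original prompt-to-prompt code), the two OmegaConf configs only need the fields dereferenced in `__init__`, and the model id and prompt are placeholder examples.

import torch
from omegaconf import OmegaConf

from svgdreamer.painter.diffusion_pipeline import DiffusionPipeline
from svgdreamer.token2attn.attn_control import AttentionStore

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only the fields read in DiffusionPipeline.__init__ are filled in; values are illustrative.
model_cfg = OmegaConf.create({
    "model_id": "runwayml/stable-diffusion-v1-5",  # example checkpoint, not prescribed by the patch
    "ldm_speed_up": False,
    "enable_xformers": False,
    "gradient_checkpoint": False,
    "cpu_offload": False,
})
diffuser_cfg = OmegaConf.create({
    "download": True,
    "force_download": False,
    "resume_download": False,
})

pipe = DiffusionPipeline(model_cfg, diffuser_cfg, device)
controller = AttentionStore()  # assumed no-arg constructor; it records attention during sampling

prompt = "a minimalist icon of a fox"
out = pipe.sample(prompt,
                  controller=controller,
                  num_inference_steps=50,
                  guidance_scale=7.5)
out.images[0].save("sample.png")

# Aggregate the 16x16 cross-attention maps recorded by the controller and
# save one visualization per prompt token.
attn_maps, tokens = pipe.get_cross_attention([prompt],
                                             controller,
                                             res=16,
                                             from_where=("up", "down"),
                                             save_path="cross_attn.png")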
diff --git a/svgdreamer/painter/loss.py b/svgdreamer/painter/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c68ad7dab8411c20ee73a18ded53924699cd6988 --- /dev/null +++ b/svgdreamer/painter/loss.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: +import torch + + +def channel_saturation_penalty_loss(x: torch.Tensor): + assert x.shape[1] == 3 + r_channel = x[:, 0, :, :] + g_channel = x[:, 1, :, :] + b_channel = x[:, 2, :, :] + channel_accumulate = torch.pow(r_channel, 2) + torch.pow(g_channel, 2) + torch.pow(b_channel, 2) + return channel_accumulate.mean() / 3 + + +def area(a, b, c): + return (c[1] - a[1]) * (b[0] - a[0]) - (b[1] - a[1]) * (c[0] - a[0]) + + +def triangle_area(A, B, C): + out = (C - A).flip([-1]) * (B - A) + out = out[..., 1] - out[..., 0] + return out + + +def compute_sine_theta(s1, s2): # s1 and s2 aret two segments to be uswed + # s1, s2 (2, 2) + v1 = s1[1, :] - s1[0, :] + v2 = s2[1, :] - s2[0, :] + # print(v1, v2) + sine_theta = (v1[0] * v2[1] - v1[1] * v2[0]) / (torch.norm(v1) * torch.norm(v2)) + return sine_theta + + +def xing_loss_fn(x_list, scale=1e-3): # x[npoints, 2] + loss = 0. + # print(f"points_len: {len(x_list)}") + for x in x_list: + # print(f"x: {x}") + seg_loss = 0. + N = x.size()[0] + assert N % 3 == 0, f'The segment number ({N}) is not correct!' + x = torch.cat([x, x[0, :].unsqueeze(0)], dim=0) # (N+1,2) + segments = torch.cat([x[:-1, :].unsqueeze(1), x[1:, :].unsqueeze(1)], dim=1) # (N, start/end, 2) + segment_num = int(N / 3) + for i in range(segment_num): + cs1 = segments[i * 3, :, :] # start control segs + cs2 = segments[i * 3 + 1, :, :] # middle control segs + cs3 = segments[i * 3 + 2, :, :] # end control segs + # print('the direction of the vectors:') + # print(compute_sine_theta(cs1, cs2)) + direct = (compute_sine_theta(cs1, cs2) >= 0).float() + opst = 1 - direct # another direction + sina = compute_sine_theta(cs1, cs3) # the angle between cs1 and cs3 + seg_loss += direct * torch.relu(- sina) + opst * torch.relu(sina) + # print(direct, opst, sina) + seg_loss /= segment_num + + templ = seg_loss + loss += templ * scale # area_loss * scale + + return loss / (len(x_list)) diff --git a/svgdreamer/painter/painter_params.py b/svgdreamer/painter/painter_params.py new file mode 100644 index 0000000000000000000000000000000000000000..2f8c53a104af5544c7e8144267fcbc3d2bf8147b --- /dev/null +++ b/svgdreamer/painter/painter_params.py @@ -0,0 +1,811 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. 
+# Author: XiMing Xing +# Description: +import math +import copy +import random +import pathlib +from typing import Dict + +from shapely.geometry.polygon import Polygon +import omegaconf +import cv2 +import numpy as np +import pydiffvg +import torch +from torch.optim.lr_scheduler import LambdaLR + +from svgdreamer.diffvg_warp import DiffVGState +from svgdreamer.libs import get_optimizer + + +class Painter(DiffVGState): + + def __init__( + self, + diffvg_cfg: omegaconf.DictConfig, + style: str, + num_segments: int, + segment_init: str, + radius: int = 20, + canvas_size: int = 600, + n_grid: int = 32, + trainable_bg: bool = False, + stroke_width: int = 3, + path_svg=None, + device=None, + ): + super().__init__(device, print_timing=diffvg_cfg.print_timing, + canvas_width=canvas_size, canvas_height=canvas_size) + + self.style = style + + self.num_segments = num_segments + self.segment_init = segment_init + self.radius = radius + + """pixelart params""" + self.n_grid = n_grid # divide the canvas into n grids + self.pixel_per_grid = self.canvas_width // self.n_grid + """sketch params""" + self.stroke_width = stroke_width + """iconography params""" + self.color_ref = None + + self.path_svg = path_svg + self.optimize_flag = [] + + self.strokes_counter = 0 # counts the number of calls to "get_path" + + # Background color + self.para_bg = torch.tensor([1., 1., 1.], requires_grad=trainable_bg, device=self.device) + + self.target_img = None + self.pos_init_method = None + + def component_wise_path_init(self, gt, pred, init_type: str = 'sparse'): + # set target image + self.target_img = gt + + if init_type == 'random': + self.pos_init_method = RandomCoordInit(self.canvas_height, self.canvas_width) + elif init_type == 'sparse': + # when initialized for the first time, the render result is None + if pred is None: + pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width) + # then pred is the render result + self.pos_init_method = SparseCoordInit(pred, gt) + elif init_type == 'naive': + if pred is None: + pred = self.para_bg.view(1, -1, 1, 1).repeat(1, 1, self.canvas_height, self.canvas_width) + self.pos_init_method = NaiveCoordInit(pred, gt) + else: + raise NotImplementedError(f"'{init_type}' is not support.") + + def init_image(self, stage=0, num_paths=0): + self.cur_shapes, self.cur_shape_groups = [], [] + + # or init svg by PyDiffVG + if self.style in ['pixelart', 'low-poly']: # update path definition + num_paths = self.n_grid + + if stage > 0: + # Noting: if multi stages training than add new strokes on existing ones + # don't optimize on previous strokes + self.optimize_flag = [False for i in range(len(self.shapes))] + for i in range(num_paths): + if self.style == 'iconography': + path = self.get_path() + self.shapes.append(path) + self.cur_shapes.append(path) + + fill_color_init = torch.FloatTensor(np.random.uniform(size=[4])) + fill_color_init[-1] = 1.0 + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([self.strokes_counter - 1]), + fill_color=fill_color_init, + stroke_color=None + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + self.optimize_flag.append(True) + + elif self.style in ['pixelart', 'low-poly']: + for j in range(num_paths): + path = self.get_path(coord=[i, j]) + self.shapes.append(path) + self.cur_shapes.append(path) + + fill_color_init = torch.FloatTensor(np.random.uniform(size=[4])) + fill_color_init[-1] = 1.0 + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.LongTensor([i * num_paths + j]), + 
fill_color=fill_color_init, + stroke_color=None, + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + self.optimize_flag.append(True) + + elif self.style in ['ink', 'sketch']: + path = self.get_path() + self.shapes.append(path) + self.cur_shapes.append(path) + + stroke_color_init = [0.0, 0.0, 0.0] + [random.random()] + stroke_color_init = torch.FloatTensor(stroke_color_init) + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([len(self.shapes) - 1]), + fill_color=None, + stroke_color=stroke_color_init + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + + elif self.style == 'painting': + path = self.get_path() + self.shapes.append(path) + self.cur_shapes.append(path) + + wref, href = self.color_ref + wref = max(0, min(int(wref), self.canvas_width - 1)) + href = max(0, min(int(href), self.canvas_height - 1)) + stroke_color_init = list(self.target_img[0, :, href, wref]) + [1.] + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([len(self.shapes) - 1]), + fill_color=None, + stroke_color=torch.FloatTensor(stroke_color_init) + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + else: + num_paths_exists = 0 + if self.path_svg is not None and pathlib.Path(self.path_svg).exists(): + print(f"-> init svg from `{self.path_svg}` ...") + + self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = self.load_svg(self.path_svg) + # if you want to add more strokes to existing ones and optimize on all of them + num_paths_exists = len(self.shapes) + + self.cur_shapes = self.shapes + self.cur_shape_groups = self.shape_groups + + for i in range(num_paths_exists, num_paths): + if self.style == 'iconography': + path = self.get_path() + self.shapes.append(path) + self.cur_shapes.append(path) + + wref, href = self.color_ref + wref = max(0, min(int(wref), self.canvas_width - 1)) + href = max(0, min(int(href), self.canvas_height - 1)) + fill_color_init = list(self.target_img[0, :, href, wref]) + [1.] 
+ path_group = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([self.strokes_counter - 1]), + fill_color=torch.FloatTensor(fill_color_init), + stroke_color=None + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + + elif self.style in ['pixelart', 'low-poly']: + for j in range(num_paths): + path = self.get_path(coord=[i, j]) + self.shapes.append(path) + self.cur_shapes.append(path) + + fill_color_init = torch.FloatTensor(np.random.uniform(size=[4])) + fill_color_init[-1] = 1.0 + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.LongTensor([i * num_paths + j]), + fill_color=fill_color_init, + stroke_color=None, + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + + elif self.style in ['sketch', 'ink']: + path = self.get_path() + self.shapes.append(path) + self.cur_shapes.append(path) + + stroke_color_init = [0.0, 0.0, 0.0] + [random.random()] + stroke_color_init = torch.FloatTensor(stroke_color_init) + + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([len(self.shapes) - 1]), + fill_color=None, + stroke_color=stroke_color_init + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + + elif self.style in ['painting']: + path = self.get_path() + self.shapes.append(path) + self.cur_shapes.append(path) + + if self.color_ref is None: + stroke_color_val = np.random.uniform(size=[4]) + stroke_color_val[-1] = 1.0 + stroke_color_init = torch.FloatTensor(stroke_color_val) + else: + wref, href = self.color_ref + wref = max(0, min(int(wref), self.canvas_width - 1)) + href = max(0, min(int(href), self.canvas_height - 1)) + stroke_color_init = list(self.target_img[0, :, href, wref]) + [1.] + stroke_color_init = torch.FloatTensor(stroke_color_init) + + path_group = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([len(self.shapes) - 1]), + fill_color=None, + stroke_color=stroke_color_init + ) + self.shape_groups.append(path_group) + self.cur_shape_groups.append(path_group) + + self.optimize_flag = [True for i in range(len(self.shapes))] + + img = self.get_image() + return img + + def get_image(self, step: int = 0): + img = self.render_warp(step) + img = img[:, :, 3:4] * img[:, :, :3] + self.para_bg * (1 - img[:, :, 3:4]) + img = img.unsqueeze(0) # convert img from HWC to NCHW + img = img.permute(0, 3, 1, 2).to(self.device) # NHWC -> NCHW + return img + + def get_path(self, coord=None): + num_segments = self.num_segments + + points = [] + if self.style == 'iconography': + # init segment + if self.segment_init == 'circle': + num_control_points = [2] * num_segments + radius = self.radius if self.radius is not None else np.random.uniform(0.5, 1) + if self.pos_init_method is not None: + center = self.pos_init_method() + else: + center = (random.random(), random.random()) + bias = center + self.color_ref = copy.deepcopy(bias) + + avg_degree = 360 / (num_segments * 3) + for i in range(0, num_segments * 3): + point = ( + np.cos(np.deg2rad(i * avg_degree)), np.sin(np.deg2rad(i * avg_degree)) + ) + points.append(point) + + points = torch.FloatTensor(points) * radius + torch.FloatTensor(bias).unsqueeze(dim=0) + elif self.segment_init == 'random': + num_control_points = [2] * num_segments + p0 = self.pos_init_method() + self.color_ref = copy.deepcopy(p0) + points.append(p0) + + for j in range(num_segments): + radius = self.radius + p1 = (p0[0] + radius * np.random.uniform(-0.5, 0.5), + p0[1] + radius * np.random.uniform(-0.5, 0.5)) + p2 = (p1[0] + radius * np.random.uniform(-0.5, 0.5), + p1[1] + radius * 
np.random.uniform(-0.5, 0.5)) + p3 = (p2[0] + radius * np.random.uniform(-0.5, 0.5), + p2[1] + radius * np.random.uniform(-0.5, 0.5)) + points.append(p1) + points.append(p2) + if j < num_segments - 1: + points.append(p3) + p0 = p3 + points = torch.FloatTensor(points) + else: + raise NotImplementedError(f"{self.segment_init} is not exists.") + + path = pydiffvg.Path( + num_control_points=torch.LongTensor(num_control_points), + points=points, + stroke_width=torch.tensor(0.0), + is_closed=True + ) + + elif self.style in ['sketch', 'painting', 'ink']: + num_control_points = torch.zeros(num_segments, dtype=torch.long) + 2 + points = [] + p0 = [random.random(), random.random()] + points.append(p0) + + # select color by first point coordinate + color_ref = copy.deepcopy(p0) + color_ref[0] *= self.canvas_width + color_ref[1] *= self.canvas_height + self.color_ref = color_ref + + for j in range(num_segments): + radius = 0.1 + p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5)) + p2 = (p1[0] + radius * (random.random() - 0.5), p1[1] + radius * (random.random() - 0.5)) + p3 = (p2[0] + radius * (random.random() - 0.5), p2[1] + radius * (random.random() - 0.5)) + points.append(p1) + points.append(p2) + points.append(p3) + p0 = p3 + points = torch.tensor(points).to(self.device) + points[:, 0] *= self.canvas_width + points[:, 1] *= self.canvas_height + + path = pydiffvg.Path(num_control_points=torch.LongTensor(num_control_points), + points=points, + stroke_width=torch.tensor(float(self.stroke_width)), + is_closed=False) + + elif self.style in ['pixelart', 'low-poly']: + x = coord[0] * self.pixel_per_grid + y = coord[1] * self.pixel_per_grid + points = torch.FloatTensor([ + [x, y], + [x + self.pixel_per_grid, y], + [x + self.pixel_per_grid, y + self.pixel_per_grid], + [x, y + self.pixel_per_grid] + ]).to(self.device) + path = pydiffvg.Polygon(points=points, + stroke_width=torch.tensor(0.0), + is_closed=True) + + self.strokes_counter += 1 + return path + + def clip_curve_shape(self): + if self.style in ['sketch', 'ink']: + for group in self.shape_groups: + group.stroke_color.data[:3].clamp_(0., 0.) # to force black stroke + group.stroke_color.data[-1].clamp_(0., 1.) # clip alpha + else: + for group in self.shape_groups: + if group.stroke_color is not None: + group.stroke_color.data.clamp_(0.0, 1.0) # clip rgba + if group.fill_color is not None: + group.fill_color.data.clamp_(0.0, 1.0) # clip rgba + + def reinitialize_paths(self, + reinit_path: bool = False, + opacity_threshold: float = None, + area_threshold: float = None, + fpath: pathlib.Path = None): + """ + reinitialize paths, also known as 'Reinitializing paths' in VectorFusion paper. + + Args: + reinit_path: whether to reinitialize paths or not. + opacity_threshold: Threshold of opacity. + area_threshold: Threshold of the closed polygon area. + fpath: The path to save the reinitialized SVG. 
+ """ + if not reinit_path: + return + if self.style not in ['iconography', 'low-poly', 'painting']: + return + + def get_keys_below_threshold(my_dict, threshold): + keys_below_threshold = [key for key, value in my_dict.items() if value < threshold] + return keys_below_threshold + + select_path_ids_by_opc = [] + select_path_ids_by_area = [] + if self.style in ['iconography', 'low-poly']: + # re-init by opacity_threshold + if opacity_threshold != 0 and opacity_threshold is not None: + opacity_record_ = {group.shape_ids.item(): group.fill_color[-1].item() + for group in self.cur_shape_groups} + # print("-> opacity_record: ", opacity_record_) + print("-> opacity_record: ", [f"{k}: {v:.3f}" for k, v in opacity_record_.items()]) + select_path_ids_by_opc = get_keys_below_threshold(opacity_record_, opacity_threshold) + print("select_path_ids_by_opc: ", select_path_ids_by_opc) + + # remove path by area_threshold + if area_threshold != 0 and area_threshold is not None: + area_records = [Polygon(shape.points.detach().cpu().numpy()).area for shape in self.cur_shapes] + # print("-> area_records: ", area_records) + print("-> area_records: ", ['%.2f' % i for i in area_records]) + for i, shape in enumerate(self.cur_shapes): + points_ = shape.points.detach().cpu().numpy() + if Polygon(points_).area < area_threshold: + select_path_ids_by_area.append(shape.id) + print("select_path_ids_by_area: ", select_path_ids_by_area) + + elif self.style in ['painting']: + # re-init by opacity_threshold + if opacity_threshold != 0 and opacity_threshold is not None: + opacity_record_ = {group.shape_ids.item(): group.stroke_color[-1].item() + for group in self.cur_shape_groups} + # print("-> opacity_record: ", opacity_record_) + print("-> opacity_record: ", [f"{k}: {v:.3f}" for k, v in opacity_record_.items()]) + select_path_ids_by_opc = get_keys_below_threshold(opacity_record_, opacity_threshold) + print("select_path_ids_by_opc: ", select_path_ids_by_opc) + + # re-init paths + reinit_union = list(set(select_path_ids_by_opc + select_path_ids_by_area)) + if len(reinit_union) > 0: + for i, path in enumerate(self.cur_shapes): + if path.id in reinit_union: + coord = [i, i] if self.style == 'low-poly' else None + self.cur_shapes[i] = self.get_path(coord=coord) + for i, group in enumerate(self.cur_shape_groups): + shp_ids = group.shape_ids.cpu().numpy().tolist() + if set(shp_ids).issubset(reinit_union): + if self.style in ['iconography', 'low-poly']: + fill_color_init = torch.FloatTensor(np.random.uniform(size=[4])) + fill_color_init[-1] = 1.0 + self.cur_shape_groups[i] = pydiffvg.ShapeGroup( + shape_ids=torch.tensor(list(shp_ids)), + fill_color=fill_color_init, + stroke_color=None) + elif self.style in ['painting']: + stroke_color_init = torch.FloatTensor(np.random.uniform(size=[4])) + stroke_color_init[-1] = 1.0 + self.cur_shape_groups[i] = pydiffvg.ShapeGroup( + shape_ids=torch.tensor([len(self.shapes) - 1]), + fill_color=None, + stroke_color=stroke_color_init + ) + # save reinit svg + self.pretty_save_svg(fpath) + + print("-" * 40) + + def calc_distance_weight(self, loss_weight_keep): + shapes_forsdf = copy.deepcopy(self.cur_shapes) + shape_groups_forsdf = copy.deepcopy(self.cur_shape_groups) + for si in shapes_forsdf: + si.stroke_width = torch.FloatTensor([0]).to(self.device) + for sg_idx, sgi in enumerate(shape_groups_forsdf): + sgi.fill_color = torch.FloatTensor([1, 1, 1, 1]).to(self.device) + sgi.shape_ids = torch.LongTensor([sg_idx]).to(self.device) + + sargs_forsdf = pydiffvg.RenderFunction.serialize_scene( + 
self.canvas_width, self.canvas_height, shapes_forsdf, shape_groups_forsdf + ) + _render = pydiffvg.RenderFunction.apply + with torch.no_grad(): + im_forsdf = _render(self.canvas_width, # width + self.canvas_height, # height + 2, # num_samples_x + 2, # num_samples_y + 0, # seed + None, + *sargs_forsdf) + + # use alpha channel is a trick to get 0-1 image + im_forsdf = (im_forsdf[:, :, 3]).detach().cpu().numpy() + loss_weight = get_sdf(im_forsdf, normalize='to1') + loss_weight += loss_weight_keep + loss_weight = np.clip(loss_weight, 0, 1) + loss_weight = torch.FloatTensor(loss_weight).to(self.device) + return loss_weight + + def set_point_parameters(self, id_delta=0): + self.point_vars = [] + for i, path in enumerate(self.cur_shapes): + path.id = i + id_delta # set point id + path.points.requires_grad = True + self.point_vars.append(path.points) + + def get_point_parameters(self): + return self.point_vars + + def set_color_parameters(self): + self.color_vars = [] + for i, group in enumerate(self.cur_shape_groups): + if group.fill_color is not None: + group.fill_color.requires_grad = True + self.color_vars.append(group.fill_color) + if group.stroke_color is not None: + group.stroke_color.requires_grad = True + self.color_vars.append(group.stroke_color) + + def get_color_parameters(self): + return self.color_vars + + def set_width_parameters(self): + # stroke`s width optimization + self.width_vars = [] + for i, path in enumerate(self.shapes): + path.stroke_width.requires_grad = True + self.width_vars.append(path.stroke_width) + + def get_width_parameters(self): + return self.width_vars + + def pretty_save_svg(self, filename, width=None, height=None, shapes=None, shape_groups=None): + width = self.canvas_width if width is None else width + height = self.canvas_height if height is None else height + shapes = self.shapes if shapes is None else shapes + shape_groups = self.shape_groups if shape_groups is None else shape_groups + + self.save_svg(filename, width, height, shapes, shape_groups, use_gamma=False, background=None) + + def load_svg(self, path_svg): + canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg) + return canvas_width, canvas_height, shapes, shape_groups + + +def get_sdf(phi, **kwargs): + import skfmm # local import + + phi = (phi - 0.5) * 2 + if (phi.max() <= 0) or (phi.min() >= 0): + return np.zeros(phi.shape).astype(np.float32) + sd = skfmm.distance(phi, dx=1) + + flip_negative = kwargs.get('flip_negative', True) + if flip_negative: + sd = np.abs(sd) + + truncate = kwargs.get('truncate', 10) + sd = np.clip(sd, -truncate, truncate) + # print(f"max sd value is: {sd.max()}") + + zero2max = kwargs.get('zero2max', True) + if zero2max and flip_negative: + sd = sd.max() - sd + elif zero2max: + raise ValueError + + normalize = kwargs.get('normalize', 'sum') + if normalize == 'sum': + sd /= sd.sum() + elif normalize == 'to1': + sd /= sd.max() + return sd + + +class SparseCoordInit: + + def __init__(self, pred, gt, format='[bs x c x 2D]', quantile_interval=200, nodiff_thres=0.1): + if torch.is_tensor(pred): + pred = pred.detach().cpu().numpy() + if torch.is_tensor(gt): + gt = gt.detach().cpu().numpy() + + if format == '[bs x c x 2D]': + self.map = ((pred[0] - gt[0]) ** 2).sum(0) + self.reference_gt = copy.deepcopy(np.transpose(gt[0], (1, 2, 0))) + elif format == ['[2D x c]']: + self.map = (np.abs(pred - gt)).sum(-1) + self.reference_gt = copy.deepcopy(gt[0]) + else: + raise ValueError + + # OptionA: Zero too small errors to avoid the error too small deadloop + 
self.map[self.map < nodiff_thres] = 0 + quantile_interval = np.linspace(0., 1., quantile_interval) + quantized_interval = np.quantile(self.map, quantile_interval) + # remove redundant + quantized_interval = np.unique(quantized_interval) + quantized_interval = sorted(quantized_interval[1:-1]) + self.map = np.digitize(self.map, quantized_interval, right=False) + self.map = np.clip(self.map, 0, 255).astype(np.uint8) + self.idcnt = {} + for idi in sorted(np.unique(self.map)): + self.idcnt[idi] = (self.map == idi).sum() + # remove smallest one to remove the correct region + self.idcnt.pop(min(self.idcnt.keys())) + + def __call__(self): + if len(self.idcnt) == 0: + h, w = self.map.shape + return [np.random.uniform(0, 1) * w, np.random.uniform(0, 1) * h] + + target_id = max(self.idcnt, key=self.idcnt.get) + _, component, cstats, ccenter = cv2.connectedComponentsWithStats( + (self.map == target_id).astype(np.uint8), + connectivity=4 + ) + # remove cid = 0, it is the invalid area + csize = [ci[-1] for ci in cstats[1:]] + target_cid = csize.index(max(csize)) + 1 + center = ccenter[target_cid][::-1] + coord = np.stack(np.where(component == target_cid)).T + dist = np.linalg.norm(coord - center, axis=1) + target_coord_id = np.argmin(dist) + coord_h, coord_w = coord[target_coord_id] + + # replace_sampling + self.idcnt[target_id] -= max(csize) + if self.idcnt[target_id] == 0: + self.idcnt.pop(target_id) + self.map[component == target_cid] = 0 + return [coord_w, coord_h] + + +class RandomCoordInit: + def __init__(self, canvas_width, canvas_height): + self.canvas_width, self.canvas_height = canvas_width, canvas_height + + def __call__(self): + w, h = self.canvas_width, self.canvas_height + return [np.random.uniform(0, 1) * w, np.random.uniform(0, 1) * h] + + +class NaiveCoordInit: + def __init__(self, pred, gt, format='[bs x c x 2D]', replace_sampling=True): + if isinstance(pred, torch.Tensor): + pred = pred.detach().cpu().numpy() + if isinstance(gt, torch.Tensor): + gt = gt.detach().cpu().numpy() + + if format == '[bs x c x 2D]': + self.map = ((pred[0] - gt[0]) ** 2).sum(0) + elif format == ['[2D x c]']: + self.map = ((pred - gt) ** 2).sum(-1) + else: + raise ValueError + self.replace_sampling = replace_sampling + + def __call__(self): + coord = np.where(self.map == self.map.max()) + coord_h, coord_w = coord[0][0], coord[1][0] + if self.replace_sampling: + self.map[coord_h, coord_w] = -1 + return [coord_w, coord_h] + + +class PainterOptimizer: + + def __init__(self, + renderer: Painter, + style: str, + num_iter: int, + lr_config: omegaconf.DictConfig, + trainable_bg: bool = False): + self.renderer = renderer + self.num_iter = num_iter + self.trainable_bg = trainable_bg + self.lr_config = lr_config + + # set optimized params via style + self.optim_point, self.optim_color, self.optim_width = { + "iconography": (True, True, False), + "pixelart": (False, True, False), + "low-poly": (True, True, False), + "sketch": (True, False, False), + "ink": (True, False, True), + "painting": (True, True, True) + }.get(style, (False, False, False)) + self.optim_bg = trainable_bg + + # set lr schedule + schedule_cfg = lr_config.schedule + if schedule_cfg.name == 'linear': + self.lr_lambda = LinearDecayWithKeepLRLambda(init_lr=lr_config.point, + keep_ratio=schedule_cfg.keep_ratio, + decay_every=self.num_iter, + decay_ratio=schedule_cfg.decay_ratio) + elif schedule_cfg.name == 'cosine': + self.lr_lambda = CosineWithWarmupLRLambda(num_steps=self.num_iter, + warmup_steps=schedule_cfg.warmup_steps, + 
warmup_start_lr=schedule_cfg.warmup_start_lr, + warmup_end_lr=schedule_cfg.warmup_end_lr, + cosine_end_lr=schedule_cfg.cosine_end_lr) + else: + print(f"{schedule_cfg.name} is not support.") + self.lr_lambda = None + + if style in ['pixelart', 'low-poly']: + pass + + self.point_optimizer = None + self.color_optimizer = None + self.width_optimizer = None + self.bg_optimizer = None + self.point_scheduler = None + + def init_optimizers(self, pid_delta: int = 0): + # optimizer + optim_cfg = self.lr_config.optim + optim_name = optim_cfg.name + + params = {} + if self.optim_point: + self.renderer.set_point_parameters(pid_delta) + params['point'] = self.renderer.get_point_parameters() + self.point_optimizer = get_optimizer(optim_name, params['point'], self.lr_config.point, optim_cfg) + + if self.optim_color: + self.renderer.set_color_parameters() + params['color'] = self.renderer.get_color_parameters() + self.color_optimizer = get_optimizer(optim_name, params['color'], self.lr_config.color, optim_cfg) + + if self.optim_width: + self.renderer.set_width_parameters() + params['width'] = self.renderer.get_width_parameters() + if len(params['width']) > 0: + self.width_optimizer = get_optimizer(optim_name, params['width'], self.lr_config.width, optim_cfg) + + if self.optim_bg: + self.renderer.para_bg.requires_grad = True + self.bg_optimizer = get_optimizer(optim_name, self.renderer.para_bg, self.lr_config.bg, optim_cfg) + + # lr schedule + if self.lr_lambda is not None and self.optim_point: + self.point_scheduler = LambdaLR(self.point_optimizer, lr_lambda=self.lr_lambda, last_epoch=-1) + + def update_lr(self): + if self.point_scheduler is not None: + self.point_scheduler.step() + + def zero_grad_(self): + if self.point_optimizer is not None: + self.point_optimizer.zero_grad() + if self.color_optimizer is not None: + self.color_optimizer.zero_grad() + if self.width_optimizer is not None: + self.width_optimizer.zero_grad() + if self.bg_optimizer is not None: + self.bg_optimizer.zero_grad() + + def step_(self): + if self.point_optimizer is not None: + self.point_optimizer.step() + if self.color_optimizer is not None: + self.color_optimizer.step() + if self.width_optimizer is not None: + self.width_optimizer.step() + if self.bg_optimizer is not None: + self.bg_optimizer.step() + + def get_lr(self) -> Dict: + lr = {} + if self.point_optimizer is not None: + lr['pnt'] = self.point_optimizer.param_groups[0]['lr'] + if self.color_optimizer is not None: + lr['clr'] = self.color_optimizer.param_groups[0]['lr'] + if self.width_optimizer is not None: + lr['wd'] = self.width_optimizer.param_groups[0]['lr'] + if self.bg_optimizer is not None: + lr['bg'] = self.bg_optimizer.param_groups[0]['lr'] + return lr + + +class LinearDecayWithKeepLRLambda: + """apply in LIVE stage""" + + def __init__(self, init_lr, keep_ratio, decay_every, decay_ratio): + self.init_lr = init_lr + self.keep_ratio = keep_ratio + self.decay_every = decay_every + self.decay_ratio = decay_ratio + + def __call__(self, n): + if n < self.keep_ratio * self.decay_every: + return self.init_lr + + decay_time = n // self.decay_every + decay_step = n % self.decay_every + lr_s = self.decay_ratio ** decay_time + lr_e = self.decay_ratio ** (decay_time + 1) + r = decay_step / self.decay_every + lr = lr_s * (1 - r) + lr_e * r + return lr + + +class CosineWithWarmupLRLambda: + """apply in fine-tuning stage""" + + def __init__(self, num_steps, warmup_steps, warmup_start_lr, warmup_end_lr, cosine_end_lr): + self.n_steps = num_steps + self.n_warmup = warmup_steps + 
self.warmup_start_lr = warmup_start_lr + self.warmup_end_lr = warmup_end_lr + self.cosine_end_lr = cosine_end_lr + + def __call__(self, n): + if n < self.n_warmup: + # linearly warmup + return self.warmup_start_lr + (n / self.n_warmup) * (self.warmup_end_lr - self.warmup_start_lr) + else: + # cosine decayed schedule + return self.cosine_end_lr + 0.5 * (self.warmup_end_lr - self.cosine_end_lr) * ( + 1 + math.cos(math.pi * (n - self.n_warmup) / (self.n_steps - self.n_warmup))) diff --git a/svgdreamer/pipelines/SVGDreamer_pipeline.py b/svgdreamer/pipelines/SVGDreamer_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..023faa34497f130d88fc13cb1cd08ce153a35508 --- /dev/null +++ b/svgdreamer/pipelines/SVGDreamer_pipeline.py @@ -0,0 +1,743 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: +import pathlib +from PIL import Image +from typing import AnyStr, Union, Tuple, List + +import omegaconf +import numpy as np +from tqdm.auto import tqdm +import torch +import torch.nn.functional as F +from torch.optim.lr_scheduler import LambdaLR +import torchvision +from torchvision import transforms +from skimage.color import rgb2gray + +from svgdreamer.libs import ModelState, get_optimizer +from svgdreamer.painter import (CompPainter, CompPainterOptimizer, xing_loss_fn, Painter, PainterOptimizer, + CosineWithWarmupLRLambda, VectorizedParticleSDSPipeline, DiffusionPipeline) +from svgdreamer.token2attn.attn_control import EmptyControl, AttentionStore +from svgdreamer.token2attn.ptp_utils import view_images +from svgdreamer.utils.plot import plot_img, plot_couple, plot_attn, save_image +from svgdreamer.utils import init_tensor_with_color, AnyPath, mkdir +from svgdreamer.svgtools import merge_svg_files, is_valid_svg +from svgdreamer.diffusers_warp import model2res + +import ImageReward as RM + + +class SVGDreamerPipeline(ModelState): + + def __init__(self, args): + assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"] + assert args.x.vpsd.n_particle >= args.x.vpsd.vsd_n_particle + assert args.x.vpsd.n_particle >= args.x.vpsd.phi_n_particle + assert args.x.vpsd.n_phi_sample >= 1 + + logdir_ = f"sd{args.seed}" \ + f"-{'vpsd' if args.skip_sive else 'sive'}" \ + f"-{args.x.style}" \ + f"-P{args.x.num_paths}" \ + f"{'-RePath' if args.x.path_reinit.use else ''}" + super().__init__(args, log_path_suffix=logdir_) + + """SIVE log dirs""" + self.sive_attn_dir = self.result_path / "SIVE_attn_logs" + self.mask_dir = self.result_path / "SIVE_mask_logs" + self.sive_init_dir = self.result_path / "SIVE_init_logs" + self.sive_final_dir = self.result_path / "SIVE_final_logs" + # fg dir + self.fg_png_logs_dir = self.result_path / "SIVE_fg_png_logs" + self.fg_svg_logs_dir = self.result_path / "SIVE_fg_svg_logs" + # bg dir + self.bg_png_logs_dir = self.result_path / "SIVE_bg_png_logs" + self.bg_svg_logs_dir = self.result_path / "SIVE_bg_svg_logs" + """VPSD log dirs""" + self.ft_png_logs_dir = self.result_path / "VPSD_png_logs" + self.ft_svg_logs_dir = self.result_path / "VPSD_svg_logs" + self.reinit_dir = self.result_path / "VPSD_reinit_logs" + self.ft_init_dir = self.result_path / "VPSD_init_logs" + self.phi_samples_dir = self.result_path / "VPSD_phi_sampling_logs" + + mkdir([self.sive_attn_dir, self.mask_dir, self.fg_png_logs_dir, self.fg_svg_logs_dir, self.sive_final_dir, + self.bg_png_logs_dir, self.bg_svg_logs_dir, self.sive_init_dir, self.reinit_dir, + self.ft_init_dir, self.phi_samples_dir, 
self.ft_png_logs_dir, self.ft_svg_logs_dir]) + + # make video log + self.make_video = self.args.mv + if self.make_video: + self.frame_idx = 0 + self.frame_log_dir = self.result_path / "frame_logs" + self.frame_log_dir.mkdir(parents=True, exist_ok=True) + + # torch Generator seed + self.g_device = torch.Generator(device=self.device).manual_seed(args.seed) + + # for convenience + self.style = self.x_cfg.style + self.im_size = self.x_cfg.image_size + self.sive_cfg = self.x_cfg.sive + self.sive_optim = self.x_cfg.sive_stage_optim + self.vpsd_cfg = self.x_cfg.vpsd + self.vpsd_optim = self.x_cfg.vpsd_stage_optim + + if self.style == "pixelart": + self.x_cfg.sive_stage_optim.lr_schedule = False + self.x_cfg.vpsd_stage_optim.lr_schedule = False + + def painterly_rendering(self, text_prompt: str, target_file: AnyPath = None): + # log prompts + self.print(f"prompt: {text_prompt}") + self.print(f"neg_prompt: {self.args.neg_prompt}\n") + + if self.args.skip_sive: + # mode 1: optimization with VPSD from scratch + self.print("optimization with VPSD from scratch...") + final_svg_path = None + elif target_file is not None: + # mode 2: load the SVG file and use VPSD finetune it (skip SIVE) + assert pathlib.Path(target_file).exists() and is_valid_svg(target_file) + self.print(f"load svg from {target_file} ...") + self.print(f"SVG fine-tuning via VPSD...") + final_svg_path = target_file + self.x_cfg.coord_init = 'sparse' + else: + # mode 3: SIVE + VPSD + final_svg_path = self.SIVE_stage(text_prompt) + self.x_cfg.path_svg = final_svg_path + self.print("\n SVG fine-tuning via VPSD...") + + self.VPSD_stage(text_prompt, final_svg_path) + self.close(msg="painterly rendering complete.") + + def SIVE_stage(self, text_prompt: str): + # init diffusion model + pipeline = DiffusionPipeline(self.x_cfg.sive_model_cfg, self.args.diffuser, self.device) + + merged_svg_paths = [] + for i in range(self.vpsd_cfg.n_particle): + select_sample_path = self.result_path / f'select_sample_{i}.png' + + # generate sample and attention map + fg_attn_map, bg_attn_map, controller = self.extract_ldm_attn(self.x_cfg.sive_model_cfg, + pipeline, + text_prompt, + select_sample_path, + self.sive_cfg.attn_cfg, + self.im_size, + self.args.token_ind) + # load selected file + select_img = self.target_file_preprocess(select_sample_path.as_posix()) + self.print(f"load target file from: {select_sample_path.as_posix()}") + + # get objects by attention map + fg_img, bg_img, fg_mask, bg_mask = self.extract_object(select_img, fg_attn_map, bg_attn_map, iter=i) + self.print(f"fg_img shape: {fg_img.shape}, bg_img: {bg_img.shape}") + + # background rendering + self.print(f"-> background rendering: ") + bg_render_path = self.component_rendering(tag=f'{i}_bg', + prompt=text_prompt, + target_img=bg_img, + mask=bg_mask, + attention_map=bg_attn_map, + canvas_size=(self.im_size, self.im_size), + render_cfg=self.sive_cfg.bg, + optim_cfg=self.sive_optim, + log_png_dir=self.bg_png_logs_dir, + log_svg_dir=self.bg_svg_logs_dir) + # foreground rendering + self.print(f"-> foreground rendering: ") + fg_render_path = self.component_rendering(tag=f'{i}_fg', + prompt=text_prompt, + target_img=fg_img, + mask=fg_mask, + attention_map=fg_attn_map, + canvas_size=(self.im_size, self.im_size), + render_cfg=self.sive_cfg.fg, + optim_cfg=self.sive_optim, + log_png_dir=self.fg_png_logs_dir, + log_svg_dir=self.fg_svg_logs_dir) + # merge foreground and background + merged_svg_path = self.result_path / f'SIVE_render_final_{i}.svg' + merge_svg_files( + svg_path_1=bg_render_path, + 
svg_path_2=fg_render_path, + merge_type='simple', + output_svg_path=merged_svg_path.as_posix(), + out_size=(self.im_size, self.im_size) + ) + merged_svg_paths.append(merged_svg_path) + + # empty attention record + controller.reset() + + # free the VRAM + del pipeline + torch.cuda.empty_cache() + + return merged_svg_paths + + def component_rendering(self, + tag: str, + prompt: AnyPath, + target_img: torch.Tensor, + mask: Union[np.ndarray, None], + attention_map: Union[np.ndarray, None], + canvas_size: Tuple[int, int], + render_cfg: omegaconf.DictConfig, + optim_cfg: omegaconf.DictConfig, + log_png_dir: pathlib.Path, + log_svg_dir: pathlib.Path): + + # set path_schedule + path_schedule = self.get_path_schedule(render_cfg.path_schedule, + render_cfg.schedule_each, + render_cfg.num_paths) + if render_cfg.style == 'pixelart': + path_schedule = [render_cfg.grid] + self.print(f"path_schedule: {path_schedule}") + + # for convenience + n_iter = render_cfg.num_iter + style = render_cfg.style + trainable_bg = render_cfg.optim_bg + total_step = len(path_schedule) * n_iter + + # set renderer + renderer = CompPainter(style, + target_img, + canvas_size, + render_cfg.num_segments, + render_cfg.segment_init, + render_cfg.radius, + render_cfg.grid, + render_cfg.width, + device=self.device, + attn_init=render_cfg.use_attn_init and attention_map is not None, + attention_map=attention_map, + attn_prob_tau=render_cfg.softmax_tau) + + if attention_map is not None: + # init fist control points by attention_map + attn_thresh, select_inds = renderer.attn_init_points(num_paths=sum(path_schedule), mask=mask) + # log attention, just once + plot_attn(attention_map, attn_thresh, target_img, select_inds, + (self.sive_attn_dir / f"attention_{tag}_map.jpg").as_posix()) + else: + # init fist control points by GT + renderer.component_wise_path_init(pred=None, init_type=render_cfg.coord_init) + + optimizer_list = [ + CompPainterOptimizer(renderer, style, n_iter, optim_cfg, trainable_bg) + for _ in range(len(path_schedule)) + ] + + pathn_record = [] + loss_weight_keep = 0 + step = 0 + loss_weight = 1 + with tqdm(initial=step, total=total_step, disable=not self.accelerator.is_main_process) as pbar: + for path_idx, pathn in enumerate(path_schedule): + # record path + pathn_record.append(pathn) + # init graphic + img = renderer.init_image(num_paths=pathn) + plot_img(img, self.sive_init_dir, fname=f"{tag}_init_img_{path_idx}") + # rebuild optimizer + optimizer_list[path_idx].init_optimizers(pid_delta=int(path_idx * pathn)) + + pbar.write(f"=> adding {pathn} paths, n_path: {sum(pathn_record)}, " + f"n_point: {len(renderer.get_point_params())}, " + f"n_width: {len(renderer.get_width_params())}, " + f"n_color: {len(renderer.get_color_params())}") + + for t in range(n_iter): + raster_img = renderer.get_image(step=t).to(self.device) + + if render_cfg.use_distance_weighted_loss and style == "iconography": + loss_weight = renderer.calc_distance_weight(loss_weight_keep) + + # reconstruction loss + if style == "pixelart": + loss_recon = torch.nn.functional.l1_loss(raster_img, target_img) + else: + if render_cfg.use_distance_weighted_loss: + # UDF loss + loss_recon = ((raster_img - target_img) ** 2) + loss_recon = (loss_recon.sum(1) * loss_weight).mean() + else: + loss_recon = F.mse_loss(raster_img, target_img) + + # Xing Loss for Self-Interaction Problem + loss_xing = torch.tensor(0.) 
+ if style == "iconography": + loss_xing = xing_loss_fn(renderer.get_point_params()) * render_cfg.xing_loss_weight + + # total loss + loss = loss_recon + loss_xing + + lr_str = "" + for k, lr in optimizer_list[path_idx].get_lr().items(): + lr_str += f"{k}_lr: {lr:.4f}, " + + pbar.set_description( + lr_str + + f"L_total: {loss.item():.4f}, " + f"L_recon: {loss_recon.item():.4f}, " + f"L_xing: {loss_xing.item():.4e}" + ) + + # optimization + for i in range(path_idx + 1): + optimizer_list[i].zero_grad_() + + loss.backward() + + for i in range(path_idx + 1): + optimizer_list[i].step_() + + renderer.clip_curve_shape() + + if render_cfg.lr_schedule: + for i in range(path_idx + 1): + optimizer_list[i].update_lr() + + if step % self.args.save_step == 0 and self.accelerator.is_main_process: + plot_couple(target_img, + raster_img, + step, + prompt=prompt, + output_dir=log_png_dir.as_posix(), + fname=f"{tag}_iter{step}") + renderer.save_svg(log_svg_dir / f"{tag}_svg_iter{step}.svg") + + step += 1 + pbar.update(1) + + if render_cfg.use_distance_weighted_loss and style == "iconography": + loss_weight_keep = loss_weight.detach().cpu().numpy() * 1 + # calc center + renderer.component_wise_path_init(raster_img) + + # end LIVE + final_svg_fpth = self.sive_final_dir / f"{tag}_final_render.svg" + renderer.save_svg(final_svg_fpth) + + return final_svg_fpth + + def VPSD_stage(self, text_prompt: AnyStr, init_svg_path: Union[List[AnyPath], AnyPath] = None): + # for convenience + guidance_cfg = self.x_cfg.vpsd + vpsd_model_cfg = self.x_cfg.vpsd_model_cfg + n_particle = guidance_cfg.n_particle + total_step = guidance_cfg.num_iter + path_reinit = self.x_cfg.path_reinit + + # init VPSD + pipeline = VectorizedParticleSDSPipeline(vpsd_model_cfg, self.args.diffuser, guidance_cfg, self.device) + # init reward model + reward_model = None + if guidance_cfg.phi_ReFL: + reward_model = RM.load("ImageReward-v1.0", device=self.device, download_root=self.x_cfg.reward_path) + + # create svg renderer + if isinstance(init_svg_path, List): + renderers = [self.load_renderer(init_path) for init_path in init_svg_path] + else: + renderers = [self.load_renderer(init_svg_path) for _ in range(n_particle)] + + if init_svg_path: # init from an SVG + target_img = self.target_file_preprocess(self.result_path / 'target_img.png') + else: # randomly init + if self.x_cfg.color_init == 'rand': # randomly init + target_img = torch.randn(1, 3, self.im_size, self.im_size) + self.print("color: randomly init") + else: # specified color + target_img = init_tensor_with_color(self.x_cfg.color_init, 1, self.im_size, self.im_size) + self.print(f"color: {self.x_cfg.color_init}") + plot_img(target_img, self.result_path, fname='target_img') + + # initialize the particles + for render in renderers: + render.component_wise_path_init(gt=target_img, pred=None, init_type=self.x_cfg.coord_init) + + # log init images + for i, r in enumerate(renderers): + init_imgs = r.init_image(stage=0, num_paths=self.x_cfg.num_paths) + plot_img(init_imgs, self.ft_init_dir, fname=f"init_img_stage_two_{i}") + + # init renderer optimizer + optimizers = [] + for renderer in renderers: + optim_ = PainterOptimizer(renderer, + self.style, + guidance_cfg.num_iter, + self.vpsd_optim, + self.x_cfg.trainable_bg) + optim_.init_optimizers() + optimizers.append(optim_) + + # init phi_model optimizer + phi_optimizer = get_optimizer('adamW', + pipeline.phi_params, + guidance_cfg.phi_lr, + guidance_cfg.phi_optim) + # init phi_model lr scheduler + phi_scheduler = None + schedule_cfg = 
guidance_cfg.phi_schedule + if schedule_cfg.use: + phi_lr_lambda = CosineWithWarmupLRLambda(num_steps=schedule_cfg.total_step, + warmup_steps=schedule_cfg.warmup_steps, + warmup_start_lr=schedule_cfg.warmup_start_lr, + warmup_end_lr=schedule_cfg.warmup_end_lr, + cosine_end_lr=schedule_cfg.cosine_end_lr) + phi_scheduler = LambdaLR(phi_optimizer, lr_lambda=phi_lr_lambda, last_epoch=-1) + + self.print(f"-> Painter point Params: {len(renderers[0].get_point_parameters())}") + self.print(f"-> Painter color Params: {len(renderers[0].get_color_parameters())}") + self.print(f"-> Painter width Params: {len(renderers[0].get_width_parameters())}") + + L_reward = torch.tensor(0.) + + self.step = 0 # reset global step + self.print(f"\ntotal VPSD optimization steps: {total_step}") + with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar: + while self.step < total_step: + # set particles + particles = [renderer.get_image() for renderer in renderers] + raster_imgs = torch.cat(particles, dim=0) + + if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1): + plot_img(raster_imgs, self.frame_log_dir, fname=f"iter{self.frame_idx}") + self.frame_idx += 1 + + L_guide, grad, latents, t_step = pipeline.variational_score_distillation( + raster_imgs, + self.step, + prompt=[text_prompt], + negative_prompt=self.args.neg_prompt, + grad_scale=guidance_cfg.grad_scale, + enhance_particle=guidance_cfg.particle_aug, + im_size=model2res(vpsd_model_cfg.model_id) + ) + + # Xing Loss for Self-Interaction Problem + L_add = torch.tensor(0.) + if self.style == "iconography" or self.x_cfg.xing_loss.use: + for r in renderers: + L_add += xing_loss_fn(r.get_point_parameters()) * self.x_cfg.xing_loss.weight + + loss = L_guide + L_add + + # optimization + for opt_ in optimizers: + opt_.zero_grad_() + loss.backward() + for opt_ in optimizers: + opt_.step_() + + # phi_model optimization + for _ in range(guidance_cfg.phi_update_step): + L_lora = pipeline.train_phi_model(latents, guidance_cfg.phi_t, as_latent=True) + + phi_optimizer.zero_grad() + L_lora.backward() + phi_optimizer.step() + + # reward learning + if guidance_cfg.phi_ReFL and self.step % guidance_cfg.phi_sample_step == 0: + with torch.no_grad(): + phi_outputs = [] + phi_sample_paths = [] + for idx in range(guidance_cfg.n_phi_sample): + phi_output = pipeline.sample(text_prompt, + num_inference_steps=guidance_cfg.phi_infer_step, + generator=self.g_device) + sample_path = (self.phi_samples_dir / f'iter{idx}.png').as_posix() + phi_output.images[0].save(sample_path) + phi_sample_paths.append(sample_path) + + phi_output_np = np.array(phi_output.images[0]) + phi_outputs.append(phi_output_np) + # save all samples + view_images(phi_outputs, save_image=True, + num_rows=max(len(phi_outputs) // 6, 1), + fp=self.phi_samples_dir / f'samples_iter{self.step}.png') + + ranking, rewards = reward_model.inference_rank(text_prompt, phi_sample_paths) + self.print(f"ranking: {ranking}, reward score: {rewards}") + + for k in range(guidance_cfg.n_phi_sample): + phi = self.target_file_preprocess(phi_sample_paths[ranking[k] - 1]) + L_reward = pipeline.train_phi_model_refl(phi, weight=rewards[k]) + + phi_optimizer.zero_grad() + L_reward.backward() + phi_optimizer.step() + + # update the learning rate of the phi_model + if phi_scheduler is not None: + phi_scheduler.step() + + # curve regularization + for r in renderers: + r.clip_curve_shape() + + # re-init paths + if self.step % path_reinit.freq == 0 and self.step < 
path_reinit.stop_step and self.step != 0: + for i, r in enumerate(renderers): + r.reinitialize_paths(path_reinit.use, # on-off + path_reinit.opacity_threshold, + path_reinit.area_threshold, + fpath=self.reinit_dir / f"reinit-{self.step}_p{i}.svg") + + # update lr + if self.vpsd_optim.lr_schedule: + for opt_ in optimizers: + opt_.update_lr() + + # log pretrained model lr + lr_str = "" + for k, lr in optimizers[0].get_lr().items(): + lr_str += f"{k}_lr: {lr:.4f}, " + # log phi model lr + cur_phi_lr = phi_optimizer.param_groups[0]['lr'] + lr_str += f"phi_lr: {cur_phi_lr:.3e}, " + + pbar.set_description( + lr_str + + f"t: {t_step.item():.2f}, " + f"L_total: {loss.item():.4f}, " + f"L_add: {L_add.item():.4e}, " + f"L_lora: {L_lora.item():.4f}, " + f"L_reward: {L_reward.item():.4f}, " + f"grad: {grad.item():.4e}" + ) + + if self.step % self.args.save_step == 0 and self.accelerator.is_main_process: + # save png + torchvision.utils.save_image(raster_imgs, + fp=self.ft_png_logs_dir / f'iter{self.step}.png') + + # save svg + for i, r in enumerate(renderers): + r.pretty_save_svg(self.ft_svg_logs_dir / f"svg_iter{self.step}_p{i}.svg") + + self.step += 1 + pbar.update(1) + + # save final + for i, r in enumerate(renderers): + ft_svg_path = self.result_path / f"finetune_final_p_{i}.svg" + r.pretty_save_svg(ft_svg_path) + # save SVGs + torchvision.utils.save_image(raster_imgs, fp=self.result_path / f'all_particles.png') + + if self.make_video: + from subprocess import call + call([ + "ffmpeg", + "-framerate", f"{self.args.framerate}", + "-i", (self.frame_log_dir / "iter%d.png").as_posix(), + "-vb", "20M", + (self.result_path / "svgdreamer_rendering.mp4").as_posix() + ]) + + def load_renderer(self, path_svg=None): + renderer = Painter(self.args.diffvg, + self.style, + self.x_cfg.num_segments, + self.x_cfg.segment_init, + self.x_cfg.radius, + self.im_size, + self.x_cfg.grid, + self.x_cfg.trainable_bg, + self.x_cfg.width, + path_svg=path_svg, + device=self.device) + + # if load a svg file, then rasterize it + save_path = self.result_path / 'target_img.png' + if path_svg is not None and (not save_path.exists()): + canvas_width, canvas_height, shapes, shape_groups = renderer.load_svg(path_svg) + render_img = renderer.render_image(canvas_width, canvas_height, shapes, shape_groups) + torchvision.utils.save_image(render_img, fp=save_path) + return renderer + + def target_file_preprocess(self, tar_path: AnyPath): + process_comp = transforms.Compose([ + transforms.Resize(size=(self.im_size, self.im_size)), + transforms.ToTensor(), + transforms.Lambda(lambda t: t.unsqueeze(0)), + ]) + + tar_pil = Image.open(tar_path).convert("RGB") # open file + target_img = process_comp(tar_pil) # preprocess + target_img = target_img.to(self.device) + return target_img + + def extract_object(self, + select_img: torch.Tensor, + fg_attn_map: np.ndarray, + bg_attn_map: np.ndarray, + iter: Union[str, int], + tau: float = 0.15): + # attention to mask + bool_fg_attn_map = fg_attn_map > tau + fg_mask = bool_fg_attn_map.astype(int) # [w, h] + + # shrunk_mask + w, h = fg_mask.shape + fg_mask[1:w - 1, 1:h - 1] = fg_mask[1:w - 1, 1:h - 1] + + bg_mask = 1 - fg_mask + + # masked image, and save in place + select_img_np = select_img.cpu().numpy() + fg_img = fg_mask * select_img_np # [1, 3, w, h] + fg_mask_ = np.expand_dims(np.array([fg_mask, fg_mask, fg_mask]), axis=0) # [w,h] -> [1,3,w,h] + fg_img[fg_mask_ == 0] = 1 + fg_img = (fg_img / fg_img.max() * 255) + save_image(fg_img[0], self.mask_dir / f'{iter}_mask_fg.png') + + bg_img = bg_mask * 
select_img_np + bg_mask_ = np.expand_dims(np.array([bg_mask, bg_mask, bg_mask]), axis=0) + bg_img[bg_mask_ == 0] = 1 + bg_img = (bg_img / bg_img.max() * 255) + save_image(bg_img[0], self.mask_dir / f'{iter}_mask_bg.png') + + # to Tensor + fg_img_final = self.target_file_preprocess(self.mask_dir / f'{iter}_mask_fg.png') + bg_img_final = self.target_file_preprocess(self.mask_dir / f'{iter}_mask_bg.png') + + # [1,3,w,h] -> [w,h] + fg_mask = fg_mask_[0][0, :, :] + bg_mask = 1 - fg_mask + return fg_img_final, bg_img_final, fg_mask, bg_mask + + def extract_ldm_attn(self, + model_cfg: omegaconf.DictConfig, + pipeline: DiffusionPipeline, + prompts: str, + gen_sample_path: AnyPath, + attn_init_cfg: omegaconf.DictConfig, + image_size: int, + token_ind: int, + attn_init: bool = True, ): + if token_ind <= 0: + raise ValueError("The 'token_ind' should be greater than 0") + + # init controller + controller = AttentionStore() if attn_init else EmptyControl() + + # forward once and record attention map + height = width = model2res(model_cfg.model_id) + outputs = pipeline.sample(prompt=[prompts], + height=height, + width=width, + num_inference_steps=model_cfg.num_inference_steps, + controller=controller, + guidance_scale=model_cfg.guidance_scale, + negative_prompt=self.args.neg_prompt, + generator=self.g_device) + outputs_np = [np.array(img) for img in outputs.images] + view_images(outputs_np, save_image=True, fp=gen_sample_path) + self.print(f"select_sample shape: {outputs_np[0].shape}") + + if attn_init: + """ldm cross-attention map""" + cross_attention_maps, tokens = \ + pipeline.get_cross_attention([prompts], + controller, + res=attn_init_cfg.cross_attn_res, + from_where=("up", "down"), + save_path=self.sive_attn_dir / "cross_attn.png") + + self.print(f"the length of tokens is {len(tokens)}, select {token_ind}-th token") + # [res, res, seq_len] + self.print(f"origin cross_attn_map shape: {cross_attention_maps.shape}") + # [res, res] + cross_attn_map = cross_attention_maps[:, :, token_ind] + self.print(f"select cross_attn_map shape: {cross_attn_map.shape}") + cross_attn_map = 255 * cross_attn_map / cross_attn_map.max() + # [res, res, 3] + cross_attn_map = cross_attn_map.unsqueeze(-1).expand(*cross_attn_map.shape, 3) + # [3, res, res] + cross_attn_map = cross_attn_map.permute(2, 0, 1).unsqueeze(0) + # [3, clip_size, clip_size] + cross_attn_map = F.interpolate(cross_attn_map, size=image_size, mode='bicubic') + cross_attn_map = torch.clamp(cross_attn_map, min=0, max=255) + # rgb to gray + cross_attn_map = rgb2gray(cross_attn_map.squeeze(0).permute(1, 2, 0)).astype(np.float32) + # torch to numpy + if cross_attn_map.shape[-1] != image_size and cross_attn_map.shape[-2] != image_size: + cross_attn_map = cross_attn_map.reshape(image_size, image_size) + # to [0, 1] + cross_attn_map = (cross_attn_map - cross_attn_map.min()) / (cross_attn_map.max() - cross_attn_map.min()) + + """ldm self-attention map""" + self_attention_maps, svd, vh_ = \ + pipeline.get_self_attention_comp([prompts], + controller, + res=attn_init_cfg.self_attn_res, + from_where=("up", "down"), + img_size=image_size, + max_com=attn_init_cfg.max_com, + save_path=self.sive_attn_dir) + + # comp self-attention map + if attn_init_cfg.mean_comp: + self_attn = np.mean(vh_, axis=0) + self.print(f"use the mean of {attn_init_cfg.max_com} comps.") + else: + self_attn = vh_[attn_init_cfg.comp_idx] + self.print(f"select {attn_init_cfg.comp_idx}-th comp.") + # to [0, 1] + self_attn = (self_attn - self_attn.min()) / (self_attn.max() - self_attn.min()) + # visual 
final self-attention + self_attn_vis = np.copy(self_attn) + self_attn_vis = self_attn_vis * 255 + self_attn_vis = np.repeat(np.expand_dims(self_attn_vis, axis=2), 3, axis=2).astype(np.uint8) + view_images(self_attn_vis, save_image=True, fp=self.sive_attn_dir / "self-attn-final.png") + + """get final attention map""" + attn_map = attn_init_cfg.attn_coeff * cross_attn_map + (1 - attn_init_cfg.attn_coeff) * self_attn + # to [0, 1] + attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min()) + # visual fusion-attention + attn_map_vis = np.copy(attn_map) + attn_map_vis = attn_map_vis * 255 + attn_map_vis = np.repeat(np.expand_dims(attn_map_vis, axis=2), 3, axis=2).astype(np.uint8) + view_images(attn_map_vis, save_image=True, fp=self.sive_attn_dir / 'fusion-attn.png') + + # inverse fusion-attention to [0, 1] + inverse_attn = 1 - attn_map + # visual reversed fusion-attention + reversed_attn_map_vis = np.copy(inverse_attn) + reversed_attn_map_vis = reversed_attn_map_vis * 255 + reversed_attn_map_vis = np.repeat(np.expand_dims(reversed_attn_map_vis, axis=2), 3, axis=2).astype(np.uint8) + view_images(reversed_attn_map_vis, save_image=True, fp=self.sive_attn_dir / 'reversed-fusion-attn.png') + + self.print(f"-> fusion attn_map: {attn_map.shape}") + else: + attn_map = None + inverse_attn = None + + return attn_map, inverse_attn, controller + + def get_path_schedule(self, + path_schedule: str, + schedule_each: Union[int, List], + num_paths: int = None): + if path_schedule == 'repeat': + assert num_paths is not None + return int(num_paths / schedule_each) * [schedule_each] + elif path_schedule == 'list': + assert isinstance(schedule_each, list) or isinstance(schedule_each, omegaconf.ListConfig) + return schedule_each + else: + raise NotImplementedError diff --git a/svgdreamer/pipelines/__init__.py b/svgdreamer/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918 --- /dev/null +++ b/svgdreamer/pipelines/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: diff --git a/svgdreamer/svgtools/__init__.py b/svgdreamer/svgtools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec9083cd96c4ef4df6ab9ae8ffefc2d71c6e9491 --- /dev/null +++ b/svgdreamer/svgtools/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Copyright (c) 2023, XiMing Xing. +# License: MIT License + +from .tff import FONT_LIST +from .type import is_valid_svg +from .merge import merge_svg_files +from .process import delete_empty_path, add_def_tag + +__all__ = [ + 'is_valid_svg', + 'merge_svg_files', + 'FONT_LIST', + 'delete_empty_path', 'add_def_tag' +] diff --git a/svgdreamer/svgtools/merge.py b/svgdreamer/svgtools/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..ef16353417a811cfa2d53f84641d8367e8c9b704 --- /dev/null +++ b/svgdreamer/svgtools/merge.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: SVGDreamer - merge +# Copyright (c) 2023, XiMing Xing. 
+# License: MIT License
+import xml.etree.ElementTree as ET
+from typing import Tuple, AnyStr
+
+import omegaconf
+from svgpathtools import svg2paths, wsvg
+
+from .type import is_valid_svg
+from .shape import *
+
+
+def merge_svg_files(
+        svg_path_1: AnyStr,
+        svg_path_2: AnyStr,
+        merge_type: str,
+        output_svg_path: AnyStr,
+        out_size: Tuple[int, int],  # e.g.: (600, 600)
+):
+    is_valid_svg(svg_path_1)
+    is_valid_svg(svg_path_2)
+
+    # set merge ops
+    if merge_type.startswith('vert'):  # Move up/down vertically
+        if '+' in merge_type:  # move up
+            move_val = merge_type.split("+")[1]
+            move_val = int(move_val)
+        elif '-' in merge_type:  # move down
+            move_val = merge_type.split("-")[1]
+            move_val = -int(move_val)
+        else:
+            raise NotImplementedError(f'{merge_type} is invalid.')
+
+        merge_svg_by_group(svg_path_1, svg_path_2,
+                           cp_offset=(0, move_val),
+                           svg_out=output_svg_path, out_size=out_size)
+
+    elif merge_type.startswith('cp'):  # Move all control points
+        if '+' in merge_type:
+            move_val = merge_type.split("+")[1]
+            move_val = int(move_val)
+        elif '-' in merge_type:
+            move_val = merge_type.split("-")[1]
+            move_val = -int(move_val)
+        else:
+            raise NotImplementedError(f'{merge_type} is invalid.')
+
+        merge_svg_by_cp(svg_path_1, svg_path_2,
+                        p_offset=move_val,
+                        svg_out=output_svg_path, out_size=out_size)
+
+    elif merge_type == 'simple':  # simply combine two SVG files
+        simple_merge(svg_path_1, svg_path_2, output_svg_path, out_size)
+    else:
+        raise NotImplementedError(f'{str(merge_type)} is not supported!')
+
+
+def simple_merge(svg_path1, svg_path2, output_path, out_size):
+    # read svg to paths
+    paths1, attributes1 = svg2paths(svg_path1)
+    paths2, attributes2 = svg2paths(svg_path2)
+    # merge paths and attributes
+    paths = paths1 + paths2
+    attributes = attributes1 + attributes2
+    # write merged svg
+    wsvg(paths,
+         attributes=attributes,
+         filename=output_path,
+         viewbox=f"0 0 {out_size[0]} {out_size[1]}")
+
+
+def merge_svg_by_group(
+        svg_path_1: AnyStr,
+        svg_path_2: AnyStr,
+        cp_offset: Tuple[float, float],
+        svg_out: AnyStr,
+        out_size: Tuple[int, int],  # e.g.: (600, 600)
+):
+    # load svg_path_1
+    tree1 = ET.parse(svg_path_1)
+    root1 = tree1.getroot()
+    # new group, and add paths from svg_path_1
+    group1 = ET.Element('g')
+    for i, element in enumerate(root1.iter()):
+        element.tag = element.tag.split('}')[-1]
+        if element.tag in ['path', 'polygon']:
+            group1.append(element)
+
+    # load svg_path_2
+    tree2 = ET.parse(svg_path_2)
+    root2 = tree2.getroot()
+    # new group, and add paths from svg_path_2
+    group2 = ET.Element('g')
+    for j, path in enumerate(root2.findall('.//{http://www.w3.org/2000/svg}path')):
+        # Remove the 'svg:' prefix from the tag name
+        path.tag = path.tag.split('}')[-1]
+        group2.append(path)
+
+    # new svg
+    svg = ET.Element('svg',
+                     xmlns="http://www.w3.org/2000/svg",
+                     version='1.1',
+                     width=str(out_size[0]),
+                     height=str(out_size[1]))
+
+    # control group2
+    if 'transform' in group2.attrib:
+        group2.attrib['transform'] += f' translate({cp_offset[0]}, {cp_offset[1]})'
+    else:
+        group2.attrib['transform'] = f'translate({cp_offset[0]}, {cp_offset[1]})'
+    # add two groups
+    svg.append(group1)
+    svg.append(group2)
+    # write svg
+    tree = ET.ElementTree(svg)
+    tree.write(svg_out, encoding='utf-8', xml_declaration=True)
+
+
+def merge_svg_by_cp(
+        svg_path_1: AnyStr,
+        svg_path_2: AnyStr,
+        p_offset: float,
+        svg_out: AnyStr,
+        out_size: Tuple[int, int],  # e.g.: (600, 600)
+):
+    # load svg_path_1
+    tree1 = ET.parse(svg_path_1)
+    root1 = tree1.getroot()
+    # new group, and add paths from svg_path_1
+    group1 = ET.Element('g')
+    for i, 
element in enumerate(root1.iter()): + element.tag = element.tag.split('}')[-1] + if element.tag in ['path', 'polygon']: + group1.append(element) + + # load svg_path_2 + tree2 = ET.parse(svg_path_2) + root2 = tree2.getroot() + + # new group, and add paths form svg_path_2 + group2 = ET.Element('g') + for j, path in enumerate(root2.findall('.//{http://www.w3.org/2000/svg}path')): + # remove the 'svg:' prefix from the tag name + path.tag = path.tag.split('}')[-1] + + d = path.get('d') + # parse paths + path_data = d.split() + new_path_data = [] + + for i in range(len(path_data)): + if path_data[i].replace('.', '').isdigit(): # get point coordinates + new_param = float(path_data[i]) + p_offset + new_path_data.append(str(new_param)) + else: + new_path_data.append(path_data[i]) + # update new d attrs + path.set('d', ' '.join(new_path_data)) + + group2.append(path) + + # new svg + svg = ET.Element('svg', + xmlns="http://www.w3.org/2000/svg", + version='1.1', + width=str(out_size[0]), + height=str(out_size[1])) + + # add two group + svg.append(group1) + svg.append(group2) + # write svg + tree = ET.ElementTree(svg) + tree.write(svg_out, encoding='utf-8', xml_declaration=True) + + +def merge_two_svgs_edit( + svg_path_1: AnyStr, + svg_path_2: AnyStr, + def_cfg: omegaconf.DictConfig, + p2_offset: Tuple[float, float], + svg_out: AnyStr, + out_size: Tuple[int, int], # e.g.: (600, 600) +): + # load svg_path_1 + tree1 = ET.parse(svg_path_1) + root1 = tree1.getroot() + # new group, and add paths form svg_path_1 + group1 = ET.Element('g') + for i, element in enumerate(root1.iter()): + element.tag = element.tag.split('}')[-1] + if element.tag in ['path', 'polygon']: + group1.append(element) + + # load svg_path_2 + tree2 = ET.parse(svg_path_2) + root2 = tree2.getroot() + + # new group, and add paths form svg_path_2 + group2 = ET.Element('g') + for j, path in enumerate(root2.findall('.//{http://www.w3.org/2000/svg}path')): + # remove the 'svg:' prefix from the tag name + path.tag = path.tag.split('}')[-1] + + d = path.get('d') + # parse paths + path_data = d.split() + new_path_data = [] + + d_idx = 0 # count digit + for i in range(len(path_data)): + if path_data[i].replace('.', '').isdigit(): # get point coordinates + d_idx += 1 + if d_idx % 2 == 1: # update y + new_param = float(path_data[i]) + (p2_offset[1]) + new_path_data.append(str(new_param)) + else: + new_path_data.append(path_data[i]) + else: + new_path_data.append(path_data[i]) + # update new d attrs + path.set('d', ' '.join(new_path_data)) + + group2.append(path) + + # new svg + svg = ET.Element('svg', + xmlns="http://www.w3.org/2000/svg", + version='1.1', + width=str(out_size[0]), + height=str(out_size[1])) + + # add two group + svg.append(group1) + svg.append(group2) + # write svg + tree = ET.ElementTree(svg) + tree.write(svg_out, encoding='utf-8', xml_declaration=True) diff --git a/svgdreamer/svgtools/process.py b/svgdreamer/svgtools/process.py new file mode 100644 index 0000000000000000000000000000000000000000..734b5a061918f1ef1ae016b4b4d387926c1c583c --- /dev/null +++ b/svgdreamer/svgtools/process.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: process +# Copyright (c) 2023, XiMing Xing. 
+# License: MIT License
+
+import xml.etree.ElementTree as ET
+from typing import Tuple
+
+import omegaconf
+
+from .shape import circle_tag, rect_tag
+from .type import is_valid_svg
+
+def delete_empty_path(input_svg: str, output_svg: str):
+    is_valid_svg(input_svg)
+
+    # read svg
+    tree = ET.parse(input_svg)
+    root = tree.getroot()
+
+    group = ET.Element('g')
+    for i, element in enumerate(root.iter()):
+        element.tag = element.tag.split('}')[-1]
+        if element.tag == 'path':
+            if element.get('d') == 'C NaN NaN' or element.get('d') == '':
+                continue
+            group.append(element)
+
+    # new svg
+    svg = ET.Element('svg',
+                     xmlns="http://www.w3.org/2000/svg",
+                     version='1.1',
+                     width=root.get('width'),
+                     height=root.get('height'),
+                     viewBox=root.get('viewBox'))
+    svg.append(group)
+    tree = ET.ElementTree(svg)
+    tree.write(output_svg, encoding='utf-8', xml_declaration=True)
+
+
+def add_clipPath2def(mounted_node: ET.Element, tag_name: str, attrs: omegaconf.DictConfig):
+    # add defs node
+    defs = ET.SubElement(mounted_node, 'defs')  # parent=mounted_node, tag='defs'
+    if tag_name == 'none':
+        return None
+    # add clipPath node
+    id = 'def_clip'
+    _circleClip = ET.SubElement(defs, 'clipPath', id='def_clip')  # parent=defs, tag='clipPath'
+    # add ops
+    if tag_name == 'circle_clip':
+        _circleClip.append(
+            circle_tag(cx=attrs.cx, cy=attrs.cy, r=attrs.r)
+        )
+    elif tag_name == 'rect_clip':
+        _circleClip.append(
+            rect_tag(x=attrs.x, y=attrs.y, rx=attrs.rx, ry=attrs.ry, width=attrs.width, height=attrs.height)
+        )
+    else:
+        raise NotImplementedError(f'{tag_name} does not exist!')
+    return id
+
+
+def add_def_tag(
+        svg_path: str,
+        def_tag_plan: str,
+        out_size: Tuple[int, int],  # e.g.: (600, 600)
+):
+    is_valid_svg(svg_path)
+
+    width, height = out_size[0], out_size[1]
+
+    # set def tag
+    if def_tag_plan == 'circle_clip':
+        def_cfg = omegaconf.DictConfig({
+            'name': 'circle_clip',
+            'attrs': {'cx': width // 2, 'cy': height // 2, 'r': int(height * 0.5)}
+        })
+    elif def_tag_plan == 'rect_clip':
+        def_cfg = omegaconf.DictConfig({
+            'name': 'rect_clip',
+            'attrs': {'x': 0, 'y': 0, 'rx': 70, 'ry': 70, 'width': width, 'height': height}
+        })
+    else:
+        raise NotImplementedError(f"'{def_tag_plan}' is not a supported def tag plan.")
+
+    # load SVG
+    tree = ET.parse(svg_path)
+    root = tree.getroot()
+    # new group, and add paths from the input SVG
+    group = ET.Element('g')
+    for i, element in enumerate(root.iter()):
+        element.tag = element.tag.split('}')[-1]
+        if element.tag in ['path', 'polygon']:
+            group.append(element)
+
+    # new svg
+    svg = ET.Element('svg',
+                     xmlns="http://www.w3.org/2000/svg",
+                     version='1.1',
+                     width=str(out_size[0]),
+                     height=str(out_size[1]))
+    # add def tag to the SVG
+    clip_id = add_clipPath2def(mounted_node=svg,
+                               tag_name=def_cfg.name,
+                               attrs=def_cfg.attrs)
+    group.set('clip-path', f'url(#{clip_id})')
+    svg.append(group)
+    # write svg
+    tree = ET.ElementTree(svg)
+    tree.write(svg_path, encoding='utf-8', xml_declaration=True)
diff --git a/svgdreamer/svgtools/shape.py b/svgdreamer/svgtools/shape.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7a8bc4f80a27b20a51ae50195ca37f93397b6a4
--- /dev/null
+++ b/svgdreamer/svgtools/shape.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Author: ximing
+# Description: SVGDreamer - shape
+# Copyright (c) 2023, XiMing Xing. 
+# License: MIT License + +import xml.etree.ElementTree as ET + + +def circle_tag(cx: float, cy: float, r: float, transform: str = None): + attrib = { + 'cx': f'{cx}', 'cy': f'{cy}', 'r': f'{r}' + } + if transform is not None: + attrib['transform'] = transform + _circle = ET.Element('circle', attrib) # tag, attrib + return _circle + + +def rect_tag( + x: float, y: float, rx: float, ry: float, + width: float = 600, height: float = 600, + transform: str = None +): + attrib = { + 'x': f'{x}', 'y': f'{y}', 'rx': f'{rx}', 'ry': f'{ry}', + 'width': f'{width}', 'height': f'{height}' + } + if transform is not None: + attrib['transform'] = transform + _rect = ET.Element('rect', attrib) # tag, attrib + return _rect diff --git a/svgdreamer/svgtools/type.py b/svgdreamer/svgtools/type.py new file mode 100644 index 0000000000000000000000000000000000000000..c22479c62f7909ebb7a09e42e7e764a089efe0d1 --- /dev/null +++ b/svgdreamer/svgtools/type.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: SVGDreamer - type checking +# Copyright (c) 2023, XiMing Xing. +# License: MIT License + +from typing import AnyStr + +import xml.etree.ElementTree as ET + + +def is_valid_svg(file_path: AnyStr) -> bool: + try: + tree = ET.parse(file_path) + root = tree.getroot() + if root.tag.endswith('svg') and 'xmlns' in root.attrib: + return True + else: + return False + except ET.ParseError: + return False diff --git a/svgdreamer/token2attn/__init__.py b/svgdreamer/token2attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad761f2f5443eb41b15afc4116a66ecdfa9d918 --- /dev/null +++ b/svgdreamer/token2attn/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. +# Author: XiMing Xing +# Description: diff --git a/svgdreamer/token2attn/attn_control.py b/svgdreamer/token2attn/attn_control.py new file mode 100644 index 0000000000000000000000000000000000000000..232d910eab5eeb53749afe0e57cda744d0f1341b --- /dev/null +++ b/svgdreamer/token2attn/attn_control.py @@ -0,0 +1,266 @@ +# -*- coding: utf-8 -*- +# Copyright (c) XiMing Xing. All rights reserved. 
+# Author: XiMing Xing +# Description: + +from abc import ABC, abstractmethod +from typing import Optional, Union, Tuple, List, Dict + +import torch +import torch.nn.functional as F + +from .ptp_utils import (get_word_inds, get_time_words_attention_alpha) +from .seq_aligner import (get_replacement_mapper, get_refinement_mapper) + + +class AttentionControl(ABC): + + def __init__(self): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + def step_callback(self, x_t): + # can be used to return a modified attention map + return x_t + + def between_steps(self): + return + + @property + def num_uncond_att_layers(self): + return 0 + + @abstractmethod + def forward(self, attn, is_cross: bool, place_in_unet: str): + raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + h = attn.shape[0] + attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + return attn + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + +class EmptyControl(AttentionControl): + + def forward(self, attn, is_cross: bool, place_in_unet: str): + return attn + + +class AttentionStore(AttentionControl): + + def __init__(self): + super(AttentionStore, self).__init__() + self.step_store = self.get_empty_store() + self.attention_store = {} + + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], + "down_self": [], "mid_self": [], "up_self": []} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= 32 ** 2: # avoid memory overhead + self.step_store[key].append(attn) + return attn + + def between_steps(self): + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + self.step_store = self.get_empty_store() + + def get_average_attention(self): + print(f"step count: {self.cur_step}") + average_attention = { + key: [item / self.cur_step for item in self.attention_store[key]] + for key in self.attention_store + } + return average_attention + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_store() + self.attention_store = {} + + +class LocalBlend: + + def __init__(self, + prompts: List[str], + words: [List[List[str]]], + tokenizer, + device, + threshold=.3, + max_num_words=77): + self.max_num_words = max_num_words + + alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words) + for i, (prompt, words_) in enumerate(zip(prompts, words)): + if type(words_) is str: + words_ = [words_] + for word in words_: + ind = get_word_inds(prompt, word, tokenizer) + alpha_layers[i, :, :, :, :, ind] = 1 + self.alpha_layers = alpha_layers.to(device) + self.threshold = threshold + + def __call__(self, x_t, attention_store): + k = 1 + maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] + maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words) for item in maps] + maps = torch.cat(maps, dim=1) + maps = (maps * self.alpha_layers).sum(-1).mean(1) + mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k)) + mask = 
F.interpolate(mask, size=(x_t.shape[2:])) + mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] + mask = mask.gt(self.threshold) + mask = (mask[:1] + mask[1:]).float() + x_t = x_t[:1] + mask * (x_t - x_t[:1]) + return x_t + + +class AttentionControlEdit(AttentionStore, ABC): + + def __init__(self, + prompts, + num_steps: int, + cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], + self_replace_steps: Union[float, Tuple[float, float]], + local_blend: Optional[LocalBlend], + tokenizer, + device): + super(AttentionControlEdit, self).__init__() + self.tokenizer = tokenizer + self.device = device + + self.batch_size = len(prompts) + self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, + self.tokenizer).to(self.device) + if type(self_replace_steps) is float: + self_replace_steps = 0, self_replace_steps + self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) + self.local_blend = local_blend # define outside + + def step_callback(self, x_t): + if self.local_blend is not None: + x_t = self.local_blend(x_t, self.attention_store) + return x_t + + def replace_self_attention(self, attn_base, att_replace): + if att_replace.shape[2] <= 16 ** 2: + return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) + else: + return att_replace + + @abstractmethod + def replace_cross_attention(self, attn_base, att_replace): + raise NotImplementedError + + def forward(self, attn, is_cross: bool, place_in_unet: str): + super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) + # FIXME not replace correctly + if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): + h = attn.shape[0] // (self.batch_size) + attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) + attn_base, attn_repalce = attn[0], attn[1:] + if is_cross: + alpha_words = self.cross_replace_alpha[self.cur_step] + attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + ( + 1 - alpha_words) * attn_repalce + attn[1:] = attn_repalce_new + else: + attn[1:] = self.replace_self_attention(attn_base, attn_repalce) + attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) + return attn + + +class AttentionReplace(AttentionControlEdit): + + def __init__(self, + prompts, + num_steps: int, + cross_replace_steps: float, + self_replace_steps: float, + local_blend: Optional[LocalBlend] = None, + tokenizer=None, + device=None): + super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, + local_blend, tokenizer, device) + self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device) + + def replace_cross_attention(self, attn_base, att_replace): + return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) + + +class AttentionRefine(AttentionControlEdit): + + def __init__(self, + prompts, + num_steps: int, + cross_replace_steps: float, + self_replace_steps: float, + local_blend: Optional[LocalBlend] = None, + tokenizer=None, + device=None): + super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, + local_blend, tokenizer, device) + self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer) + self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device) + self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) + + def replace_cross_attention(self, attn_base, att_replace): + attn_base_replace = 
attn_base[:, :, self.mapper].permute(2, 0, 1, 3) + attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) + return attn_replace + + +class AttentionReweight(AttentionControlEdit): + + def __init__(self, + prompts, + num_steps: int, + cross_replace_steps: float, + self_replace_steps: float, + equalizer, + local_blend: Optional[LocalBlend] = None, + controller: Optional[AttentionControlEdit] = None, + tokenizer=None, + device=None): + super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, + local_blend, tokenizer, device) + self.equalizer = equalizer.to(self.device) + self.prev_controller = controller + + def replace_cross_attention(self, attn_base, att_replace): + if self.prev_controller is not None: + attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) + attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] + return attn_replace + + +def get_equalizer(tokenizer, text: str, + word_select: Union[int, Tuple[int, ...]], + values: Union[List[float], Tuple[float, ...]]): + if type(word_select) is int or type(word_select) is str: + word_select = (word_select,) + equalizer = torch.ones(len(values), 77) + values = torch.tensor(values, dtype=torch.float32) + for word in word_select: + inds = get_word_inds(text, word, tokenizer) + equalizer[:, inds] = values + return equalizer diff --git a/svgdreamer/token2attn/ptp_utils.py b/svgdreamer/token2attn/ptp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70b9847f2b6f965b37513ed342f88d3a646fecd4 --- /dev/null +++ b/svgdreamer/token2attn/ptp_utils.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +import pathlib +from typing import Union, Optional, List, Tuple, Dict, Text, BinaryIO +from PIL import Image + +import torch +import cv2 +import numpy as np +import matplotlib.pyplot as plt + +from .seq_aligner import get_word_inds + + +def text_under_image(image: np.ndarray, + text: str, + text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray: + h, w, c = image.shape + offset = int(h * .2) + img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 + font = cv2.FONT_HERSHEY_SIMPLEX + img[:h] = image + textsize = cv2.getTextSize(text, font, 1, 2)[0] + text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 + cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) + return img + + +def view_images( + images: Union[np.ndarray, List[np.ndarray]], + num_rows: int = 1, + offset_ratio: float = 0.02, + save_image: bool = False, + fp: Union[Text, pathlib.Path, BinaryIO] = None, +) -> np.ndarray: + if save_image: + assert fp is not None + + if isinstance(images, list): + images = np.concatenate(images, axis=0) + + if isinstance(images, np.ndarray) and images.ndim == 4: + num_empty = images.shape[0] % num_rows + else: + images = [images] if not isinstance(images, list) else images + num_empty = len(images) % num_rows + + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty + num_items = len(images) + + # Calculate the composite image + h, w, c = images[0].shape + offset = int(h * offset_ratio) + num_cols = int(np.ceil(num_items / num_rows)) # count the number of columns + image_h = h * num_rows + offset * (num_rows - 1) + image_w = w * num_cols + offset * (num_cols - 1) + assert image_h > 0, "Invalid image height: {} (num_rows={}, offset_ratio={}, num_items={})".format( + image_h, num_rows, offset_ratio, 
num_items) + assert image_w > 0, "Invalid image width: {} (num_cols={}, offset_ratio={}, num_items={})".format( + image_w, num_cols, offset_ratio, num_items) + image_ = np.ones((image_h, image_w, 3), dtype=np.uint8) * 255 + + # Ensure that the last row is filled with empty images if necessary + if len(images) % num_cols > 0: + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + num_empty = num_cols - len(images) % num_cols + images += [empty_images] * num_empty + + for i in range(num_rows): + for j in range(num_cols): + k = i * num_cols + j + if k >= num_items: + break + image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[k] + + pil_img = Image.fromarray(image_) + if save_image: + pil_img.save(fp) + return pil_img + + +def update_alpha_time_word(alpha, + bounds: Union[float, Tuple[float, float]], + prompt_ind: int, + word_inds: Optional[torch.Tensor] = None): + if isinstance(bounds, float): + bounds = 0, bounds + start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) + if word_inds is None: + word_inds = torch.arange(alpha.shape[2]) + alpha[: start, prompt_ind, word_inds] = 0 + alpha[start: end, prompt_ind, word_inds] = 1 + alpha[end:, prompt_ind, word_inds] = 0 + return alpha + + +def get_time_words_attention_alpha(prompts, num_steps, + cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], + tokenizer, + max_num_words=77): + if type(cross_replace_steps) is not dict: + cross_replace_steps = {"default_": cross_replace_steps} + if "default_" not in cross_replace_steps: + cross_replace_steps["default_"] = (0., 1.) + alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) + for i in range(len(prompts) - 1): + alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], + i) + for key, item in cross_replace_steps.items(): + if key != "default_": + inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] + for i, ind in enumerate(inds): + if len(ind) > 0: + alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) + alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) + return alpha_time_words diff --git a/svgdreamer/token2attn/seq_aligner.py b/svgdreamer/token2attn/seq_aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..d534d8ae1b6618604c619d56250293f66c0430f5 --- /dev/null +++ b/svgdreamer/token2attn/seq_aligner.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- +import torch +import numpy as np + + +class ScoreParams: + + def __init__(self, gap, match, mismatch): + self.gap = gap + self.match = match + self.mismatch = mismatch + + def mis_match_char(self, x, y): + if x != y: + return self.mismatch + else: + return self.match + + +def get_matrix(size_x, size_y, gap): + matrix = [] + for i in range(len(size_x) + 1): + sub_matrix = [] + for j in range(len(size_y) + 1): + sub_matrix.append(0) + matrix.append(sub_matrix) + for j in range(1, len(size_y) + 1): + matrix[0][j] = j * gap + for i in range(1, len(size_x) + 1): + matrix[i][0] = i * gap + return matrix + + +def get_matrix(size_x, size_y, gap): + matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) + matrix[0, 1:] = (np.arange(size_y) + 1) * gap + matrix[1:, 0] = (np.arange(size_x) + 1) * gap + return matrix + + +def get_traceback_matrix(size_x, size_y): + matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) + matrix[0, 1:] = 1 + matrix[1:, 0] = 2 + matrix[0, 0] = 4 + return 
matrix + + +def global_align(x, y, score): + matrix = get_matrix(len(x), len(y), score.gap) + trace_back = get_traceback_matrix(len(x), len(y)) + for i in range(1, len(x) + 1): + for j in range(1, len(y) + 1): + left = matrix[i, j - 1] + score.gap + up = matrix[i - 1, j] + score.gap + diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) + matrix[i, j] = max(left, up, diag) + if matrix[i, j] == left: + trace_back[i, j] = 1 + elif matrix[i, j] == up: + trace_back[i, j] = 2 + else: + trace_back[i, j] = 3 + return matrix, trace_back + + +def get_aligned_sequences(x, y, trace_back): + x_seq = [] + y_seq = [] + i = len(x) + j = len(y) + mapper_y_to_x = [] + while i > 0 or j > 0: + if trace_back[i, j] == 3: + x_seq.append(x[i - 1]) + y_seq.append(y[j - 1]) + i = i - 1 + j = j - 1 + mapper_y_to_x.append((j, i)) + elif trace_back[i][j] == 1: + x_seq.append('-') + y_seq.append(y[j - 1]) + j = j - 1 + mapper_y_to_x.append((j, -1)) + elif trace_back[i][j] == 2: + x_seq.append(x[i - 1]) + y_seq.append('-') + i = i - 1 + elif trace_back[i][j] == 4: + break + mapper_y_to_x.reverse() + return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) + + +def get_mapper(x: str, y: str, tokenizer, max_len=77): + x_seq = tokenizer.encode(x) + y_seq = tokenizer.encode(y) + score = ScoreParams(0, 1, -1) + matrix, trace_back = global_align(x_seq, y_seq, score) + mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] + alphas = torch.ones(max_len) + alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() + mapper = torch.zeros(max_len, dtype=torch.int64) + mapper[:mapper_base.shape[0]] = mapper_base[:, 1] + mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) + return mapper, alphas + + +def get_refinement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers, alphas = [], [] + for i in range(1, len(prompts)): + mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + alphas.append(alpha) + return torch.stack(mappers), torch.stack(alphas) + + +def get_word_inds(text: str, word_place: int, tokenizer): + split_text = text.split(" ") + if type(word_place) is str: + word_place = [i for i, word in enumerate(split_text) if word_place == word] + elif type(word_place) is int: + word_place = [word_place] + out = [] + if len(word_place) > 0: + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + cur_len, ptr = 0, 0 + + for i in range(len(words_encode)): + cur_len += len(words_encode[i]) + if ptr in word_place: + out.append(i + 1) + if cur_len >= len(split_text[ptr]): + ptr += 1 + cur_len = 0 + return np.array(out) + + +def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): + words_x = x.split(' ') + words_y = y.split(' ') + if len(words_x) != len(words_y): + raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" + f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") + inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] + inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] + inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] + mapper = np.zeros((max_len, max_len)) + i = j = 0 + cur_inds = 0 + while i < max_len and j < max_len: + if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: + inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] + if len(inds_source_) == len(inds_target_): + 
mapper[inds_source_, inds_target_] = 1 + else: + ratio = 1 / len(inds_target_) + for i_t in inds_target_: + mapper[inds_source_, i_t] = ratio + cur_inds += 1 + i += len(inds_source_) + j += len(inds_target_) + elif cur_inds < len(inds_source): + mapper[i, j] = 1 + i += 1 + j += 1 + else: + mapper[j, j] = 1 + i += 1 + j += 1 + + return torch.from_numpy(mapper).float() + + +def get_replacement_mapper(prompts, tokenizer, max_len=77): + x_seq = prompts[0] + mappers = [] + for i in range(1, len(prompts)): + mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) + mappers.append(mapper) + return torch.stack(mappers) diff --git a/svgdreamer/utils/__init__.py b/svgdreamer/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c46ae0d33ca7d6efeaee7463a827532a86a33fa8 --- /dev/null +++ b/svgdreamer/utils/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: __init__.py +# Copyright (c) 2023, XiMing Xing. +# License: MPL-2.0 License + +from .misc import * +from .color_attrs import get_rgb_from_color, init_tensor_with_color diff --git a/svgdreamer/utils/color_attrs.py b/svgdreamer/utils/color_attrs.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ca1c1df8c6e86981141e07c17d79c25feb6612 --- /dev/null +++ b/svgdreamer/utils/color_attrs.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: shape_group +# Copyright (c) 2023, XiMing Xing. +# License: MIT License + +from typing import Tuple + +import torch +from matplotlib import colors + + +def init_tensor_with_rgb( + rgb: Tuple[float, float, float], + b: int, + w: int, + h: int, + norm: bool = False +): + """ + Initializes a PyTorch tensor with the specified RGB values. The tensor has shape (b, 3, w, h). + + Args: + rgb: RGB values, shape (3,) + b: Batch size + w: Width + h: Height + norm: normalize the tensor to range [0, 1] + + Examples: + >>> rgb = (0.5, 0.2, 0.1) # Specify RGB values + >>> tensor = init_tensor_with_rgb(rgb, 1, 100, 100, norm=False) # Initialize tensor + + Returns: + Initialized tensor + """ + + # Convert RGB values to tensor + rgb = torch.tensor(rgb, dtype=torch.float) + + # Create tensor + tensor = torch.zeros((b, 3, w, h), dtype=torch.float) + + # Assign RGB values to tensor + tensor[:, 0] = rgb[0] + tensor[:, 1] = rgb[1] + tensor[:, 2] = rgb[2] + + if norm: + tensor = tensor / 255. + + return tensor + + +def init_tensor_with_color( + color: str, + b: int, + w: int, + h: int, + norm: bool = True +): + """ + Initializes a PyTorch tensor with the specified RGB values. The tensor has shape (b, 3, w, h). 
+ + Args: + color: + b: Batch size + w: Width + h: Height + norm: normalize the tensor to range [0, 1] + + Examples: + >>> color = '#B0A695' # Specify RGB values + >>> tensor = init_tensor_with_rgb(color, 1, 100, 100) # Initialize tensor + + Returns: + Initialized tensor + """ + + rgb = get_rgb_from_color(color) + + # Convert RGB values to tensor + rgb = torch.tensor(rgb, dtype=torch.float) + + # Create tensor + tensor = torch.zeros((b, 3, w, h), dtype=torch.float) + + # Assign RGB values to tensor + tensor[:, 0] = rgb[0] + tensor[:, 1] = rgb[1] + tensor[:, 2] = rgb[2] + + return tensor + + +def hex_to_rgb(hex_code): + r = int(hex_code[0:2], 16) + g = int(hex_code[2:4], 16) + b = int(hex_code[4:6], 16) + return (r, g, b) + + +def get_rgb_from_color(color: str): + # get the corresponding RGB value based on the color + if color.startswith('#'): + color = color.split('#')[1] + rgb = hex_to_rgb(color) + rgb = [c / 255. for c in rgb] # to [0, 1] + elif color in colors.cnames: + rgb = colors.to_rgb(color) + else: + rgb = color + return rgb + + +if __name__ == "__main__": + color = '#B0A695' + + rgb = get_rgb_from_color(color) + + print(rgb) diff --git a/svgdreamer/utils/inpaint_util.py b/svgdreamer/utils/inpaint_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bbe52967aaa28a6cd697aa7b725b5dccf74414 --- /dev/null +++ b/svgdreamer/utils/inpaint_util.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: inpaint_util +# Copyright (c) 2023, XiMing Xing. +# License: MIT License + +import os +import pathlib + +import cv2 +import numpy as np +from omegaconf import OmegaConf +from tqdm import trange +import torch +from torch.utils.data._utils.collate import default_collate + + +def apply_lama_inpaint(predict_config, device): + # local import + from lama.saicinpainting.evaluation.utils import move_to_device + from lama.saicinpainting.evaluation.refinement import refine_predict + from lama.saicinpainting.training.data.datasets import make_default_val_dataset + from lama.saicinpainting.training.trainers import load_checkpoint + + try: + train_config_path = pathlib.Path(predict_config.model.path) / 'config.yaml' + train_config = OmegaConf.load(train_config_path) + + train_config.training_model.predict_only = True + train_config.visualizer.kind = 'noop' + + out_ext = predict_config.get('out_ext', '.png') + + checkpoint_path = os.path.join( + predict_config.model.path, 'models', predict_config.model.checkpoint + ) + model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu') + model.freeze() + + if not predict_config.get('refine', False): + model.to(device) + + if not predict_config.indir.endswith('/'): + predict_config.indir += '/' + + dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset) + for img_i in trange(len(dataset)): + mask_fname = dataset.mask_filenames[img_i] + cur_out_fname = os.path.join( + predict_config.outdir, + os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext + ) + os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True) + batch = default_collate([dataset[img_i]]) + + if predict_config.get('refine', False): + assert 'unpad_to_size' in batch, "Unpadded size is required for the refinement" + # image unpadding is taken care of in the refiner, so that output image + # is same size as the input image + cur_res = refine_predict(batch, model, **predict_config.refiner) + cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy() + else: + with torch.no_grad(): + 
batch = move_to_device(batch, device) + batch['mask'] = (batch['mask'] > 0) * 1 + batch = model(batch) + cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy() + unpad_to_size = batch.get('unpad_to_size', None) + if unpad_to_size is not None: + orig_height, orig_width = unpad_to_size + cur_res = cur_res[:orig_height, :orig_width] + + cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8') + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) + cv2.imwrite(cur_out_fname, cur_res) + + except KeyboardInterrupt: + print('Interrupted by user') + except Exception as ex: + print(f'Prediction failed due to:') + print(f'{ex}') + import sys + sys.exit(1) diff --git a/svgdreamer/utils/misc.py b/svgdreamer/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe633ec4ae5c907d875527cf0aae19c721caf66 --- /dev/null +++ b/svgdreamer/utils/misc.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Description: misc +# Copyright (c) 2023, XiMing Xing. +# License: MPL-2.0 License + +from datetime import datetime +import random +import pathlib +from typing import Any, List, Dict, Union + +import omegaconf + +"""Add Type""" +AnyPath = Union[str, pathlib.Path, 'os.PathLike'] +AnyList = Union[omegaconf.ListConfig, List] +AnyDict = Union[omegaconf.DictConfig, Dict] + + +def render_batch_wrap(cfg: omegaconf.DictConfig, + seed_range: List, + pipeline: Any, + **pipe_args): + start_time = datetime.now() + for idx, seed in enumerate(seed_range): + cfg.seed = seed # update seed + print(f"\n-> [{idx}/{len(seed_range)}], " + f"current seed: {seed}, " + f"current time: {datetime.now() - start_time}\n") + pipe = pipeline(cfg) + pipe.painterly_rendering(**pipe_args) + + +def get_seed_range(srange: AnyList): + # random sampling without specifying a range + start_, end_ = 1, 1000000 + if srange is not None: # specify range sequential sampling + seed_range_ = list(srange) + assert len(seed_range_) == 2 and int(seed_range_[1]) > int(seed_range_[0]) + start_, end_ = int(seed_range_[0]), int(seed_range_[1]) + seed_range = [i for i in range(start_, end_)] + else: + # a list of lengths 1000 sampled from the range start_ to end_ (e.g.: [1, 1000000]) + numbers = list(range(start_, end_)) + seed_range = random.sample(numbers, k=1000) + return seed_range + + +def mkdir(dirs: List[pathlib.Path]): + for _dir in dirs: + _dir.mkdir(parents=True, exist_ok=True) diff --git a/svgdreamer/utils/plot.py b/svgdreamer/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..d181ba2031abc9cbe357e19de7ee41878e0d1bbf --- /dev/null +++ b/svgdreamer/utils/plot.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Author: ximing +# Copyright (c) 2023, XiMing Xing. 
+# License: MPL-2.0 License + +from typing import AnyStr, BinaryIO, Union +from PIL import Image +import pathlib + +import numpy as np +import matplotlib.pyplot as plt +import torch +from torchvision.utils import make_grid + +from .misc import AnyPath + + +def save_image(image_array: np.ndarray, fname: AnyPath): + image = np.transpose(image_array, (1, 2, 0)).astype(np.uint8) + pil_image = Image.fromarray(image) + pil_image.save(fname) + + +def plot_attn(attn: np.ndarray, + threshold_map: np.ndarray, + inputs: torch.Tensor, + inds: np.ndarray, + output_path: AnyPath): + # currently supports one image (and not a batch) + plt.figure(figsize=(10, 5)) + + plt.subplot(1, 3, 1) + main_im = make_grid(inputs, normalize=True, pad_value=2) + main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0)) + plt.imshow(main_im, interpolation='nearest') + plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o') + plt.title("input img") + plt.axis("off") + + plt.subplot(1, 3, 2) + plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1) + plt.title("attn map") + plt.axis("off") + + plt.subplot(1, 3, 3) + threshold_map_ = (threshold_map - threshold_map.min()) / \ + (threshold_map.max() - threshold_map.min()) + plt.imshow(np.nan_to_num(threshold_map_), interpolation='nearest', vmin=0, vmax=1) + plt.title("prob softmax") + plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o') + plt.axis("off") + + plt.tight_layout() + plt.savefig(output_path) + plt.close() + + +def plot_couple(input_1: torch.Tensor, + input_2: torch.Tensor, + step: int, + output_dir: str, + fname: AnyPath, # file name + prompt: str = '', # text prompt as image tile + pad_value: float = 0, + dpi: int = 300): + if input_1.shape != input_2.shape: + raise ValueError("inputs and outputs must have the same dimensions") + + plt.figure() + plt.subplot(1, 2, 1) # nrows=1, ncols=2, index=1 + grid = make_grid(input_1, normalize=True, pad_value=pad_value) + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + plt.imshow(ndarr) + plt.axis("off") + plt.title("Input") + + plt.subplot(1, 2, 2) # nrows=1, ncols=2, index=2 + grid = make_grid(input_2, normalize=True, pad_value=pad_value) + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + plt.imshow(ndarr) + plt.axis("off") + plt.title(f"Rendering - {step} steps") + + def insert_newline(string, point=9): + # split by blank + words = string.split() + if len(words) <= point: + return string + + word_chunks = [words[i:i + point] for i in range(0, len(words), point)] + new_string = "\n".join(" ".join(chunk) for chunk in word_chunks) + return new_string + + plt.suptitle(insert_newline(prompt), fontsize=10) + + plt.tight_layout() + plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi) + plt.close() + + +def plot_img(inputs: torch.Tensor, + output_dir: AnyStr, + fname: AnyPath, # file name + pad_value: float = 0): + assert torch.is_tensor(inputs), f"The input must be tensor type, but got {type(inputs)}" + + grid = make_grid(inputs, normalize=True, pad_value=pad_value) + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + + plt.imshow(ndarr) + plt.axis("off") + plt.tight_layout() + plt.close() + + im = Image.fromarray(ndarr) + im.save(f"{output_dir}/{fname}.png") + + +def plot_img_title(inputs: torch.Tensor, + title: str, + output_dir: AnyStr, + fname: AnyPath, # file name + pad_value: float = 0, + dpi: int = 500): + assert torch.is_tensor(inputs), f"The input must be tensor 
type, but got {type(inputs)}" + + grid = make_grid(inputs, normalize=True, pad_value=pad_value) + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + plt.imshow(ndarr) + plt.axis("off") + plt.title(f"{title}") + plt.savefig(f"{output_dir}/{fname}.png", dpi=dpi) + plt.close()
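+
+
+if __name__ == "__main__":
+    # Minimal usage sketch for the helpers above; the random tensors and the
+    # ./plot_demo output directory are illustrative assumptions, not fixtures
+    # shipped with the module.
+    demo_dir = pathlib.Path("./plot_demo")
+    demo_dir.mkdir(parents=True, exist_ok=True)
+
+    fake_target = torch.rand(1, 3, 64, 64)
+    fake_render = torch.rand(1, 3, 64, 64)
+
+    plot_img(fake_target, demo_dir.as_posix(), fname="demo_img")
+    plot_couple(fake_target, fake_render, step=0,
+                output_dir=demo_dir.as_posix(),
+                fname="demo_couple",
+                prompt="a demo prompt")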