"""Latent-direction "sliders" for Stable Diffusion text embeddings.

A direction is estimated by encoding many prompt pairs that differ only in one
word (``target_word`` vs ``opposite``) and averaging the difference between the
text encoder outputs. At generation time the direction is added, scaled, to the
prompt embeddings before they are handed to the diffusion pipeline.
"""
import random
import time
from typing import Optional

import diffusers  # noqa: F401  (not referenced directly in this module)
import torch
from PIL import Image
from tqdm import tqdm

from constants import SUBJECTS, MEDIUMS


class CLIPSlider:
    """Steer a Stable Diffusion pipeline along CLIP text-embedding directions.

    A direction is the mean difference between text-encoder outputs for prompts
    containing ``target_word`` and prompts containing ``opposite``.
    """

    def __init__(
        self,
        sd_pipe,
        device: torch.device,
        target_word: str = "",
        opposite: str = "",
        target_word_2nd: str = "",
        opposite_2nd: str = "",
        iterations: int = 300,
    ):
        """Move ``sd_pipe`` to ``device`` in float16 and optionally precompute directions.

        Args:
            sd_pipe: diffusers pipeline exposing ``tokenizer`` and ``text_encoder``
                (plus ``tokenizer_2`` / ``text_encoder_2`` for the XL subclass).
            device: device the pipeline and all token ids are moved to.
            target_word, opposite: word pair for the primary direction. If both are
                empty, no direction is computed and ``self.avg_diff`` is ``None``.
            target_word_2nd, opposite_2nd: optional second word pair, stored in
                ``self.avg_diff_2nd`` (``None`` when both are empty).
            iterations: default number of random prompt pairs averaged per direction.
        """
        self.device = device
        self.pipe = sd_pipe.to(self.device, torch.float16)
        self.iterations = iterations
        if target_word != "" or opposite != "":
            self.avg_diff = self.find_latent_direction(target_word, opposite)
        else:
            self.avg_diff = None
        if target_word_2nd != "" or opposite_2nd != "":
            self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
        else:
            self.avg_diff_2nd = None

    # ------------------------------------------------------------------ helpers

    def _tokenize(self, tokenizer, prompt, device=None):
        """Return max-length padded/truncated input ids for ``prompt`` on ``device``.

        ``device`` defaults to ``self.device`` (the device the pipeline was moved to).
        """
        input_ids = tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).input_ids
        return input_ids.to(self.device if device is None else device)

    def _resolve_iterations(self, num_iterations):
        """Return ``num_iterations`` (or the instance default), rejecting counts < 1."""
        iterations = self.iterations if num_iterations is None else num_iterations
        if iterations < 1:
            # torch.cat on an empty list would otherwise fail with an opaque error.
            raise ValueError(f"number of iterations must be >= 1, got {iterations}")
        return iterations

    @staticmethod
    def _sample_prompt_pair(target_word, opposite):
        """Build a (positive, negative) prompt pair sharing a random medium and subject."""
        # Medium is drawn before subject so seeded runs reproduce the same prompts.
        medium = random.choice(MEDIUMS)
        subject = random.choice(SUBJECTS)
        pos_prompt = f"a {medium} of a {target_word} {subject}"
        neg_prompt = f"a {medium} of a {opposite} {subject}"
        return pos_prompt, neg_prompt

    def _resolve_directions(self, avg_diff, avg_diff_2nd):
        """Fall back to the directions computed in ``__init__`` when none is passed.

        An explicitly passed ``avg_diff`` is used together with the ``avg_diff_2nd``
        passed alongside it. Only when ``avg_diff`` is omitted are both directions
        taken from the instance.

        Raises:
            ValueError: if no primary direction is available at all.
        """
        if avg_diff is None:
            avg_diff = self.avg_diff
            if avg_diff_2nd is None:
                avg_diff_2nd = self.avg_diff_2nd
        if avg_diff is None:
            raise ValueError(
                "no latent direction available: pass avg_diff or construct the "
                "slider with target_word/opposite"
            )
        return avg_diff, avg_diff_2nd

    @staticmethod
    def _normalize_scales(scale, scale_2nd):
        """Rescale so that ``|scale| + |scale_2nd| == 1``; an all-zero pair is left as is."""
        denominator = abs(scale) + abs(scale_2nd)
        if denominator == 0:
            return scale, scale_2nd
        return scale / denominator, scale_2nd / denominator

    @staticmethod
    def _apply_direction(prompt_embeds, eos_idx, direction, direction_2nd, scale,
                         scale_2nd, only_pooler, correlation_weight_factor):
        """Shift ``prompt_embeds`` along ``direction`` (and optionally ``direction_2nd``).

        ``direction`` must broadcast against one token row of ``prompt_embeds``,
        e.g. the ``(1, hidden)`` tensors produced by ``find_latent_direction``.

        With ``only_pooler`` only the token at ``eos_idx`` is shifted, in place.
        Otherwise every token is shifted, weighted by the cosine similarity between
        that token and the ``eos_idx`` token. ``correlation_weight_factor``
        interpolates between uniform weights (0) and pure similarity weights (1).
        """
        if only_pooler:
            prompt_embeds[:, eos_idx] = prompt_embeds[:, eos_idx] + direction * scale
            if direction_2nd is not None:
                prompt_embeds[:, eos_idx] += direction_2nd * scale_2nd
            return prompt_embeds

        normed = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
        sims = normed[0] @ normed[0].T  # token-to-token cosine similarities
        # (1, seq_len, 1): broadcasts over the hidden dim, so any encoder width works.
        weights = sims[eos_idx, :][None, :, None]
        weights = 1 + (weights - 1) * correlation_weight_factor
        prompt_embeds = prompt_embeds + weights * direction * scale
        if direction_2nd is not None:
            prompt_embeds = prompt_embeds + weights * direction_2nd * scale_2nd
        return prompt_embeds

    # --------------------------------------------------------------- public API

    def find_latent_direction(self, target_word: str, opposite: str,
                              num_iterations: Optional[int] = None):
        """Estimate the ``target_word`` - ``opposite`` direction in pooled CLIP space.

        Samples ``num_iterations`` (default ``self.iterations``) random prompt pairs
        of the form "a {medium} of a {word} {subject}". It then averages the
        difference between the text encoder's ``pooler_output`` for the positive
        and negative prompt.

        Returns:
            The mean difference tensor, keeping a leading dim of 1 (``keepdim=True``).

        Raises:
            ValueError: if the number of iterations is < 1.
        """
        iterations = self._resolve_iterations(num_iterations)
        tokenizer = self.pipe.tokenizer
        text_encoder = self.pipe.text_encoder
        diffs = []
        with torch.no_grad():
            for _ in tqdm(range(iterations)):
                pos_prompt, neg_prompt = self._sample_prompt_pair(target_word, opposite)
                pos = text_encoder(self._tokenize(tokenizer, pos_prompt)).pooler_output
                neg = text_encoder(self._tokenize(tokenizer, neg_prompt)).pooler_output
                diffs.append(pos - neg)
        return torch.cat(diffs, dim=0).mean(0, keepdim=True)

    def generate(self,
                 prompt="a photo of a house",
                 scale=2.,
                 scale_2nd=0.,  # scale for the 2nd dim directions when avg_diff_2nd is not None
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,  # whether to normalize the scales when avg_diff_2nd is not None
                 correlation_weight_factor=1.0,
                 avg_diff=None,
                 avg_diff_2nd=None,
                 **pipeline_kwargs
                 ):
        """Generate one image for ``prompt`` with its embeddings shifted along a direction.

        If doing the full sequence, scales in [-0.3, 0.3] work well (higher when
        correlation weighting is on). With ``only_pooler``, [-4, 4] works well.

        Args:
            avg_diff, avg_diff_2nd: directions to apply. When ``avg_diff`` is None,
                the directions computed in ``__init__`` are used.
            pipeline_kwargs: forwarded to the underlying pipeline call.

        Returns:
            The first image of the pipeline output.

        Raises:
            ValueError: if no direction is passed and none was computed in ``__init__``.
        """
        avg_diff, avg_diff_2nd = self._resolve_directions(avg_diff, avg_diff_2nd)
        if avg_diff_2nd is not None and normalize_scales:
            scale, scale_2nd = self._normalize_scales(scale, scale_2nd)

        with torch.no_grad():
            toks = self._tokenize(self.pipe.tokenizer, prompt)
            prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
            # argmax of the ids locates the EOS token (assumes EOS has the highest id,
            # as with CLIP's tokenizer; argmax returns its first occurrence).
            prompt_embeds = self._apply_direction(
                prompt_embeds, int(toks.argmax()), avg_diff, avg_diff_2nd,
                scale, scale_2nd, only_pooler, correlation_weight_factor,
            )

        torch.manual_seed(seed)
        image = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images[0]
        return image

    def spectrum(self,
                 prompt="a photo of a house",
                 low_scale=-2,
                 low_scale_2nd=-2,
                 high_scale=2,
                 high_scale_2nd=2,
                 steps=5,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Render ``steps`` images with scales swept linearly from low to high.

        Both scales are interpolated in lockstep. With ``steps == 1``, only the low
        scales are used. Images are pasted left to right on one RGB canvas whose
        cell size is taken from the first image (640x640 if there are none).
        """
        denominator = max(steps - 1, 1)  # avoids division by zero when steps == 1
        images = []
        for i in range(steps):
            scale = low_scale + (high_scale - low_scale) * i / denominator
            scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / denominator
            image = self.generate(
                prompt,
                scale=scale,
                scale_2nd=scale_2nd,
                seed=seed,
                only_pooler=only_pooler,
                normalize_scales=normalize_scales,
                correlation_weight_factor=correlation_weight_factor,
                **pipeline_kwargs,
            )
            # generate() already returns a single PIL image; do not index it.
            images.append(image)

        cell_w, cell_h = images[0].size if images else (640, 640)
        canvas = Image.new('RGB', (cell_w * steps, cell_h))
        for i, im in enumerate(images):
            canvas.paste(im, (cell_w * i, 0))
        return canvas


class CLIPSliderXL(CLIPSlider):
    """SDXL variant that keeps one direction per text encoder.

    Directions are ``(text_encoder_direction, text_encoder_2_direction)`` tuples.
    ``generate`` edits both encoders' embeddings before concatenating them.
    """

    def find_latent_direction(self, target_word: str, opposite: str,
                              num_iterations: Optional[int] = None):
        """Estimate the ``target_word`` - ``opposite`` direction for both SDXL encoders.

        Returns:
            ``(avg_diff, avg_diff2)``. ``avg_diff`` is the mean ``pooler_output``
            difference from ``text_encoder``. ``avg_diff2`` is the mean
            ``text_embeds`` difference from ``text_encoder_2``. Both keep a leading
            dim of 1.

        Raises:
            ValueError: if the number of iterations is < 1.
        """
        iterations = self._resolve_iterations(num_iterations)
        diffs = []
        diffs2 = []
        with torch.no_grad():
            for _ in tqdm(range(iterations)):
                pos_prompt, neg_prompt = self._sample_prompt_pair(target_word, opposite)

                pos = self.pipe.text_encoder(
                    self._tokenize(self.pipe.tokenizer, pos_prompt)).pooler_output
                neg = self.pipe.text_encoder(
                    self._tokenize(self.pipe.tokenizer, neg_prompt)).pooler_output
                diffs.append(pos - neg)

                pos2 = self.pipe.text_encoder_2(
                    self._tokenize(self.pipe.tokenizer_2, pos_prompt)).text_embeds
                neg2 = self.pipe.text_encoder_2(
                    self._tokenize(self.pipe.tokenizer_2, neg_prompt)).text_embeds
                diffs2.append(pos2 - neg2)

        avg_diff = torch.cat(diffs, dim=0).mean(0, keepdim=True)
        avg_diff2 = torch.cat(diffs2, dim=0).mean(0, keepdim=True)
        return (avg_diff, avg_diff2)

    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 avg_diff=None,
                 avg_diff_2nd=None,
                 init_latents=None,  # inversion
                 zs=None,  # inversion
                 **pipeline_kwargs
                 ):
        """Generate one SDXL image with both encoders' embeddings shifted.

        If doing the full sequence, scales in [-0.3, 0.3] work well (higher when
        correlation weighting is on). With ``only_pooler``, [-4, 4] works well.

        ``avg_diff`` / ``avg_diff_2nd`` are per-encoder tuples, as returned by
        ``find_latent_direction``; element ``i`` is applied to encoder ``i``.

        Returns:
            The first image of the pipeline output.

        Raises:
            ValueError: if no direction is passed and none was computed in ``__init__``.
        """
        start_time = time.time()
        avg_diff, avg_diff_2nd = self._resolve_directions(avg_diff, avg_diff_2nd)
        if avg_diff_2nd is not None and normalize_scales:
            scale, scale_2nd = self._normalize_scales(scale, scale_2nd)

        text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
        tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
        with torch.no_grad():
            prompt_embeds_list = []
            for i, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
                toks = self._tokenize(tokenizer, prompt, device=text_encoder.device)
                outputs = text_encoder(toks, output_hidden_states=True)
                # Only the final encoder's outputs[0] survives the loop; it is what
                # gets passed to the pipeline as pooled_prompt_embeds.
                pooled_prompt_embeds = outputs[0]
                # Penultimate hidden layer is the per-token embedding fed to the pipeline.
                prompt_embeds = outputs.hidden_states[-2]
                # Each encoder gets its own direction (element i of the tuples).
                prompt_embeds = self._apply_direction(
                    prompt_embeds,
                    int(toks.argmax()),
                    avg_diff[i],
                    None if avg_diff_2nd is None else avg_diff_2nd[i],
                    scale,
                    scale_2nd,
                    only_pooler,
                    correlation_weight_factor,
                )
                bs_embed, seq_len, _ = prompt_embeds.shape
                prompt_embeds_list.append(prompt_embeds.view(bs_embed, seq_len, -1))

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(torch.float16)
            pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1).to(torch.float16)

        end_time = time.time()
        print(f"generation time - before pipe: {(end_time - start_time) * 1000:.2f} ms")

        torch.manual_seed(seed)
        start_time = time.time()
        if init_latents is not None:  # inversion
            # NOTE(review): init_latents and zs select this branch but are not
            # forwarded to the pipeline; confirm what the inversion pipeline expects.
            image = self.pipe(prompt_embeds=prompt_embeds,
                              pooled_prompt_embeds=pooled_prompt_embeds,
                              avg_diff=avg_diff,
                              avg_diff_2=avg_diff_2nd,
                              scale=scale,
                              **pipeline_kwargs).images[0]
        else:
            image = self.pipe(prompt_embeds=prompt_embeds,
                              pooled_prompt_embeds=pooled_prompt_embeds,
                              **pipeline_kwargs).images[0]
        end_time = time.time()
        print(f"generation time - pipe: {(end_time - start_time) * 1000:.2f} ms")
        return image