linoyts committed
Commit 3409336
1 Parent(s): 6cf4fca

Update clip_slider_pipeline.py

Files changed (1)
  1. clip_slider_pipeline.py +80 -8
clip_slider_pipeline.py CHANGED
@@ -210,8 +210,6 @@ class CLIPSliderXL(CLIPSlider):
                  correlation_weight_factor = 1.0,
                  avg_diff = None,
                  avg_diff_2nd = None,
-                 init_latents = None, # inversion
-                 zs = None, # inversion
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -289,14 +287,88 @@ class CLIPSliderXL(CLIPSlider):
         print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
         torch.manual_seed(seed)
         start_time = time.time()
-        if init_latents is not None: # inversion
-            image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
-                              avg_diff=avg_diff, avg_diff_2=avg_diff2, scale=scale,
-                              **pipeline_kwargs).images[0]
-        else:
-            image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
+        image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                           **pipeline_kwargs).images[0]
         end_time = time.time()
         print(f"generation time - pipe: {end_time - start_time:.2f} ms")
 
         return image
+
+class CLIPSliderXL_inv(CLIPSlider):
+
+    def find_latent_direction(self,
+                              target_word: str,
+                              opposite: str,
+                              num_iterations: int = None):
+
+        # lets identify a latent direction by taking differences between opposites
+        # target_word = "happy"
+        # opposite = "sad"
+        if num_iterations is not None:
+            iterations = num_iterations
+        else:
+            iterations = self.iterations
+
+        with torch.no_grad():
+            positives = []
+            negatives = []
+            positives2 = []
+            negatives2 = []
+            for i in tqdm(range(iterations)):
+                medium = random.choice(MEDIUMS)
+                subject = random.choice(SUBJECTS)
+                pos_prompt = f"a {medium} of a {target_word} {subject}"
+                neg_prompt = f"a {medium} of a {opposite} {subject}"
+
+                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                pos = self.pipe.text_encoder(pos_toks).pooler_output
+                neg = self.pipe.text_encoder(neg_toks).pooler_output
+                positives.append(pos)
+                negatives.append(neg)
+
+                pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
+                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
+                positives2.append(pos2)
+                negatives2.append(neg2)
+
+        positives = torch.cat(positives, dim=0)
+        negatives = torch.cat(negatives, dim=0)
+        diffs = positives - negatives
+        avg_diff = diffs.mean(0, keepdim=True)
+
+        positives2 = torch.cat(positives2, dim=0)
+        negatives2 = torch.cat(negatives2, dim=0)
+        diffs2 = positives2 - negatives2
+        avg_diff2 = diffs2.mean(0, keepdim=True)
+        return (avg_diff, avg_diff2)
+
+    def generate(self,
+                 prompt = "a photo of a house",
+                 scale = 2,
+                 scale_2nd = 2,
+                 seed = 15,
+                 only_pooler = False,
+                 normalize_scales = False,
+                 correlation_weight_factor = 1.0,
+                 avg_diff=None,
+                 avg_diff_2nd=None,
+                 init_latents=None,
+                 zs=None,
+                 **pipeline_kwargs
+                 ):
+
+        with torch.no_grad():
+            torch.manual_seed(seed)
+            images = self.pipe(editing_prompt=prompt, init_latents=init_latents, zs=zs,
+                               avg_diff=avg_diff[0], avg_diff_2=avg_diff[1],
+                               scale=scale,
+                               **pipeline_kwargs).images
+
+        return images
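
The new class takes over the inversion path removed from CLIPSliderXL.generate above: the slider direction is still the mean difference of pooled text embeddings over antonym prompt pairs (one tensor per SDXL text encoder), while generate now forwards init_latents and zs from a prior DDPM inversion straight to the pipeline. Below is a minimal usage sketch, not code from this commit: the CLIPSlider constructor arguments and the load_editing_pipeline / invert helpers are assumptions, since neither the base class nor the inversion code appears in this diff, and the wrapped pipeline must accept the editing_prompt / init_latents / zs / avg_diff keywords that generate passes through.

import torch
from clip_slider_pipeline import CLIPSliderXL_inv

# Hypothetical helpers -- stand-ins for code that lives outside this diff.
# The pipeline's __call__ must accept editing_prompt, init_latents, zs,
# avg_diff and avg_diff_2, i.e. the repo's custom editing pipeline rather
# than the stock SDXL pipeline.
pipe = load_editing_pipeline("stabilityai/stable-diffusion-xl-base-1.0")  # hypothetical loader
init_latents, zs = invert(pipe, source_image)  # hypothetical DDPM-inversion helper; source_image is the image to edit

slider = CLIPSliderXL_inv(sd_pipe=pipe, device=torch.device("cuda"))  # assumed constructor args

# One-time direction estimate: averaged pooled-embedding differences over
# antonym prompt pairs, returned as an (avg_diff, avg_diff2) tuple.
avg_diffs = slider.find_latent_direction(target_word="happy", opposite="sad")

images = slider.generate(
    prompt="a photo of a house",
    scale=1.5,
    avg_diff=avg_diffs,          # generate() indexes [0] and [1] of this tuple
    init_latents=init_latents,   # latents recovered from the source image
    zs=zs,                       # per-step noise maps from the inversion
)

Passing the tuple from find_latent_direction directly as avg_diff matches how generate unpacks it (avg_diff[0] for the first text encoder, avg_diff[1] for the second).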