Update clip_slider_pipeline.py

clip_slider_pipeline.py  (+80 -8)  CHANGED
@@ -210,8 +210,6 @@ class CLIPSliderXL(CLIPSlider):
         correlation_weight_factor = 1.0,
         avg_diff = None,
         avg_diff_2nd = None,
-        init_latents = None, # inversion
-        zs = None, # inversion
         **pipeline_kwargs
         ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -289,14 +287,88 @@
         print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
         torch.manual_seed(seed)
         start_time = time.time()
-
-            image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
-                              avg_diff=avg_diff, avg_diff_2=avg_diff2, scale=scale,
-                              **pipeline_kwargs).images[0]
-        else:
-            image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
+        image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                           **pipeline_kwargs).images[0]
         end_time = time.time()
         print(f"generation time - pipe: {end_time - start_time:.2f} ms")
 
         return image
+
+class CLIPSliderXL_inv(CLIPSlider):
+
+    def find_latent_direction(self,
+                              target_word: str,
+                              opposite: str,
+                              num_iterations: int = None):
+
+        # lets identify a latent direction by taking differences between opposites
+        # target_word = "happy"
+        # opposite = "sad"
+        if num_iterations is not None:
+            iterations = num_iterations
+        else:
+            iterations = self.iterations
+
+        with torch.no_grad():
+            positives = []
+            negatives = []
+            positives2 = []
+            negatives2 = []
+            for i in tqdm(range(iterations)):
+                medium = random.choice(MEDIUMS)
+                subject = random.choice(SUBJECTS)
+                pos_prompt = f"a {medium} of a {target_word} {subject}"
+                neg_prompt = f"a {medium} of a {opposite} {subject}"
+
+                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
+                pos = self.pipe.text_encoder(pos_toks).pooler_output
+                neg = self.pipe.text_encoder(neg_toks).pooler_output
+                positives.append(pos)
+                negatives.append(neg)
+
+                pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
+                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.cuda()
+                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
+                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
+                positives2.append(pos2)
+                negatives2.append(neg2)
+
+        positives = torch.cat(positives, dim=0)
+        negatives = torch.cat(negatives, dim=0)
+        diffs = positives - negatives
+        avg_diff = diffs.mean(0, keepdim=True)
+
+        positives2 = torch.cat(positives2, dim=0)
+        negatives2 = torch.cat(negatives2, dim=0)
+        diffs2 = positives2 - negatives2
+        avg_diff2 = diffs2.mean(0, keepdim=True)
+        return (avg_diff, avg_diff2)
+
+    def generate(self,
+                 prompt = "a photo of a house",
+                 scale = 2,
+                 scale_2nd = 2,
+                 seed = 15,
+                 only_pooler = False,
+                 normalize_scales = False,
+                 correlation_weight_factor = 1.0,
+                 avg_diff=None,
+                 avg_diff_2nd=None,
+                 init_latents=None,
+                 zs=None,
+                 **pipeline_kwargs
+                 ):
+
+        with torch.no_grad():
+            torch.manual_seed(seed)
+            images = self.pipe(editing_prompt=prompt, init_latents=init_latents, zs=zs,
+                               avg_diff=avg_diff[0], avg_diff_2=avg_diff[1],
+                               scale=scale,
+                               **pipeline_kwargs).images
+
+        return images
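For context, a minimal usage sketch (not part of the commit) of the CLIPSliderXL_inv class added above. It assumes the CLIPSlider base class defined elsewhere in clip_slider_pipeline.py wires up self.pipe, self.iterations, the SDXL tokenizers/text encoders, and that the wrapped pipeline accepts editing_prompt, init_latents, zs, avg_diff, avg_diff_2 and scale as the new generate() forwards them; the constructor arguments, the externally supplied pipe, and the inversion step producing init_latents/zs are assumptions, not part of this diff.

import torch

# Constructor signature assumed; the real CLIPSlider __init__ lives elsewhere
# in clip_slider_pipeline.py and may take different arguments.
slider = CLIPSliderXL_inv(sd_pipe=pipe, device=torch.device("cuda"))

# Estimate the edit direction from sampled "target vs. opposite" prompt pairs;
# returns a tuple with one direction per SDXL text encoder.
avg_diff = slider.find_latent_direction("happy", "sad", num_iterations=200)

# init_latents / zs are expected to come from an inversion of a source image
# (computed elsewhere) and are forwarded unchanged to the pipeline.
images = slider.generate(
    prompt="a photo of a house",
    scale=1.5,
    avg_diff=avg_diff,          # tuple; generate() indexes avg_diff[0] / avg_diff[1]
    init_latents=init_latents,
    zs=zs,
    num_inference_steps=25,     # forwarded via **pipeline_kwargs (assumed supported)
)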