Spaces:
Running
on
Zero
Running
on
Zero
Update clip_slider_pipeline.py
Browse files- clip_slider_pipeline.py +30 -20
clip_slider_pipeline.py
CHANGED
@@ -10,17 +10,20 @@ class CLIPSlider:
|
|
10 |
self,
|
11 |
sd_pipe,
|
12 |
device: torch.device,
|
13 |
-
target_word: str,
|
14 |
-
opposite: str,
|
15 |
target_word_2nd: str = "",
|
16 |
opposite_2nd: str = "",
|
17 |
iterations: int = 300,
|
18 |
):
|
19 |
|
20 |
self.device = device
|
21 |
-
self.pipe = sd_pipe.to(self.device)
|
22 |
self.iterations = iterations
|
23 |
-
|
|
|
|
|
|
|
24 |
if target_word_2nd != "" or opposite_2nd != "":
|
25 |
self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
|
26 |
else:
|
@@ -29,12 +32,15 @@ class CLIPSlider:
|
|
29 |
|
30 |
def find_latent_direction(self,
|
31 |
target_word:str,
|
32 |
-
opposite:str):
|
33 |
|
34 |
# lets identify a latent direction by taking differences between opposites
|
35 |
# target_word = "happy"
|
36 |
# opposite = "sad"
|
37 |
-
|
|
|
|
|
|
|
38 |
|
39 |
with torch.no_grad():
|
40 |
positives = []
|
@@ -70,6 +76,8 @@ class CLIPSlider:
|
|
70 |
only_pooler = False,
|
71 |
normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
|
72 |
correlation_weight_factor = 1.0,
|
|
|
|
|
73 |
**pipeline_kwargs
|
74 |
):
|
75 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
@@ -80,14 +88,14 @@ class CLIPSlider:
|
|
80 |
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
81 |
prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
|
82 |
|
83 |
-
if
|
84 |
denominator = abs(scale) + abs(scale_2nd)
|
85 |
scale = scale / denominator
|
86 |
scale_2nd = scale_2nd / denominator
|
87 |
if only_pooler:
|
88 |
-
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] +
|
89 |
-
if
|
90 |
-
prompt_embeds[:, toks.argmax()] +=
|
91 |
else:
|
92 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
93 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
@@ -99,9 +107,9 @@ class CLIPSlider:
|
|
99 |
|
100 |
# weights = torch.sigmoid((weights-0.5)*7)
|
101 |
prompt_embeds = prompt_embeds + (
|
102 |
-
weights *
|
103 |
-
if
|
104 |
-
prompt_embeds += weights *
|
105 |
|
106 |
|
107 |
torch.manual_seed(seed)
|
@@ -399,6 +407,8 @@ class T5SliderFlux(CLIPSlider):
|
|
399 |
only_pooler = False,
|
400 |
normalize_scales = False,
|
401 |
correlation_weight_factor = 1.0,
|
|
|
|
|
402 |
**pipeline_kwargs
|
403 |
):
|
404 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
@@ -438,14 +448,14 @@ class T5SliderFlux(CLIPSlider):
|
|
438 |
dtype = self.pipe.text_encoder_2.dtype
|
439 |
prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
|
440 |
print("1", prompt_embeds.shape)
|
441 |
-
if
|
442 |
denominator = abs(scale) + abs(scale_2nd)
|
443 |
scale = scale / denominator
|
444 |
scale_2nd = scale_2nd / denominator
|
445 |
if only_pooler:
|
446 |
-
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] +
|
447 |
-
if
|
448 |
-
prompt_embeds[:, toks.argmax()] +=
|
449 |
else:
|
450 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
451 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
@@ -457,11 +467,11 @@ class T5SliderFlux(CLIPSlider):
|
|
457 |
|
458 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
459 |
prompt_embeds = prompt_embeds + (
|
460 |
-
weights *
|
461 |
print("2", prompt_embeds.shape)
|
462 |
-
if
|
463 |
prompt_embeds += (
|
464 |
-
weights *
|
465 |
|
466 |
torch.manual_seed(seed)
|
467 |
images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|
|
|
10 |
self,
|
11 |
sd_pipe,
|
12 |
device: torch.device,
|
13 |
+
target_word: str = "",
|
14 |
+
opposite: str = "",
|
15 |
target_word_2nd: str = "",
|
16 |
opposite_2nd: str = "",
|
17 |
iterations: int = 300,
|
18 |
):
|
19 |
|
20 |
self.device = device
|
21 |
+
self.pipe = sd_pipe.to(self.device, torch.float16)
|
22 |
self.iterations = iterations
|
23 |
+
if target_word != "" or opposite != "":
|
24 |
+
self.avg_diff = self.find_latent_direction(target_word, opposite)
|
25 |
+
else:
|
26 |
+
self.avg_diff = None
|
27 |
if target_word_2nd != "" or opposite_2nd != "":
|
28 |
self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
|
29 |
else:
|
|
|
32 |
|
33 |
def find_latent_direction(self,
|
34 |
target_word:str,
|
35 |
+
opposite:str, num_iterations: int = None):
|
36 |
|
37 |
# lets identify a latent direction by taking differences between opposites
|
38 |
# target_word = "happy"
|
39 |
# opposite = "sad"
|
40 |
+
if num_iterations is not None:
|
41 |
+
iterations = num_iterations
|
42 |
+
else:
|
43 |
+
iterations = self.iterations
|
44 |
|
45 |
with torch.no_grad():
|
46 |
positives = []
|
|
|
76 |
only_pooler = False,
|
77 |
normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
|
78 |
correlation_weight_factor = 1.0,
|
79 |
+
avg_diff = None,
|
80 |
+
avg_diff_2nd = None,
|
81 |
**pipeline_kwargs
|
82 |
):
|
83 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
|
|
88 |
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
89 |
prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
|
90 |
|
91 |
+
if avg_diff_2nd and normalize_scales:
|
92 |
denominator = abs(scale) + abs(scale_2nd)
|
93 |
scale = scale / denominator
|
94 |
scale_2nd = scale_2nd / denominator
|
95 |
if only_pooler:
|
96 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
|
97 |
+
if avg_diff_2nd:
|
98 |
+
prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
|
99 |
else:
|
100 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
101 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
|
|
107 |
|
108 |
# weights = torch.sigmoid((weights-0.5)*7)
|
109 |
prompt_embeds = prompt_embeds + (
|
110 |
+
weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
|
111 |
+
if avg_diff_2nd:
|
112 |
+
prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
|
113 |
|
114 |
|
115 |
torch.manual_seed(seed)
|
|
|
407 |
only_pooler = False,
|
408 |
normalize_scales = False,
|
409 |
correlation_weight_factor = 1.0,
|
410 |
+
avg_diff = None,
|
411 |
+
avg_diff_2nd = None,
|
412 |
**pipeline_kwargs
|
413 |
):
|
414 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
|
|
448 |
dtype = self.pipe.text_encoder_2.dtype
|
449 |
prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
|
450 |
print("1", prompt_embeds.shape)
|
451 |
+
if avg_diff_2nd and normalize_scales:
|
452 |
denominator = abs(scale) + abs(scale_2nd)
|
453 |
scale = scale / denominator
|
454 |
scale_2nd = scale_2nd / denominator
|
455 |
if only_pooler:
|
456 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
|
457 |
+
if avg_diff_2nd:
|
458 |
+
prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
|
459 |
else:
|
460 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
461 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
|
|
467 |
|
468 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
469 |
prompt_embeds = prompt_embeds + (
|
470 |
+
weights * avg_diff * scale)
|
471 |
print("2", prompt_embeds.shape)
|
472 |
+
if avg_diff_2nd:
|
473 |
prompt_embeds += (
|
474 |
+
weights * avg_diff_2nd * scale_2nd)
|
475 |
|
476 |
torch.manual_seed(seed)
|
477 |
images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|