linoyts HF staff commited on
Commit
f2b4569
1 Parent(s): f5b8512

Update clip_slider_pipeline.py

Browse files
Files changed (1) hide show
  1. clip_slider_pipeline.py +3 -3
clip_slider_pipeline.py CHANGED
@@ -241,8 +241,8 @@ class CLIPSliderXL(CLIPSlider):
241
  )
242
 
243
  # We are only ALWAYS interested in the pooled output of the final text encoder
244
- pooled_prompt_embeds = prompt_embeds[0].to(torch.float16)
245
- prompt_embeds = prompt_embeds.hidden_states[-2].to(torch.float16)
246
  print("prompt_embeds.dtype",prompt_embeds.dtype)
247
  if avg_diff_2nd and normalize_scales:
248
  denominator = abs(scale) + abs(scale_2nd)
@@ -286,7 +286,7 @@ class CLIPSliderXL(CLIPSlider):
286
  print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
287
  torch.manual_seed(seed)
288
  start_time = time.time()
289
- image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
290
  **pipeline_kwargs).images[0]
291
  end_time = time.time()
292
  print(f"generation time - pipe: {end_time - start_time:.2f} ms")
 
241
  )
242
 
243
  # We are only ALWAYS interested in the pooled output of the final text encoder
244
+ pooled_prompt_embeds = prompt_embeds[0]
245
+ prompt_embeds = prompt_embeds.hidden_states[-2]
246
  print("prompt_embeds.dtype",prompt_embeds.dtype)
247
  if avg_diff_2nd and normalize_scales:
248
  denominator = abs(scale) + abs(scale_2nd)
 
286
  print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
287
  torch.manual_seed(seed)
288
  start_time = time.time()
289
+ image = self.pipe(prompt_embeds=prompt_embeds.to(torch.float16), pooled_prompt_embeds=pooled_prompt_embeds.to(torch.float16),
290
  **pipeline_kwargs).images[0]
291
  end_time = time.time()
292
  print(f"generation time - pipe: {end_time - start_time:.2f} ms")