linoyts HF staff commited on
Commit
f5b8512
1 Parent(s): a0503f3

Update clip_slider_pipeline.py

Browse files
Files changed (1) hide show
  1. clip_slider_pipeline.py +3 -4
clip_slider_pipeline.py CHANGED
@@ -236,14 +236,13 @@ class CLIPSliderXL(CLIPSlider):
236
  toks = text_inputs.input_ids
237
 
238
  prompt_embeds = text_encoder(
239
- toks.to(text_encoder.device, torch.float16),
240
  output_hidden_states=True,
241
  )
242
 
243
  # We are only ALWAYS interested in the pooled output of the final text encoder
244
- pooled_prompt_embeds = prompt_embeds[0]
245
-
246
- prompt_embeds = prompt_embeds.hidden_states[-2]
247
  print("prompt_embeds.dtype",prompt_embeds.dtype)
248
  if avg_diff_2nd and normalize_scales:
249
  denominator = abs(scale) + abs(scale_2nd)
 
236
  toks = text_inputs.input_ids
237
 
238
  prompt_embeds = text_encoder(
239
+ toks.to(text_encoder.device),
240
  output_hidden_states=True,
241
  )
242
 
243
  # We are only ALWAYS interested in the pooled output of the final text encoder
244
+ pooled_prompt_embeds = prompt_embeds[0].to(torch.float16)
245
+ prompt_embeds = prompt_embeds.hidden_states[-2].to(torch.float16)
 
246
  print("prompt_embeds.dtype",prompt_embeds.dtype)
247
  if avg_diff_2nd and normalize_scales:
248
  denominator = abs(scale) + abs(scale_2nd)