Upload pipeline.py with huggingface_hub
Browse files- pipeline.py +7 -10
pipeline.py
CHANGED
@@ -124,10 +124,12 @@ class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
|
|
124 |
# at least PRESENT
|
125 |
self.text_encoder = self.text_encoder_2 = None
|
126 |
|
127 |
-
#
|
128 |
-
# Encode a text prompt
|
129 |
-
#
|
130 |
-
#
|
|
|
|
|
131 |
def encode_prompt(
|
132 |
self,
|
133 |
prompt,
|
@@ -146,7 +148,7 @@ class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
|
|
146 |
where B = batch * num_images_per_prompt
|
147 |
"""
|
148 |
|
149 |
-
# --- helper to tokenize on the pipeline
|
150 |
def _tok(text: str):
|
151 |
tok_out = self.tokenizer(
|
152 |
text,
|
@@ -180,9 +182,4 @@ class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
|
|
180 |
else:
|
181 |
tok_neg = pool_neg = None
|
182 |
|
183 |
-
# ----------------- final ordered return --------------------------
|
184 |
-
# 1) positive token embeddings
|
185 |
-
# 2) negative token embeddings (or None)
|
186 |
-
# 3) positive pooled embeddings
|
187 |
-
# 4) negative pooled embeddings (or None)
|
188 |
return tok_pos, tok_neg, pool_pos, pool_neg
|
|
|
124 |
# at least PRESENT
|
125 |
self.text_encoder = self.text_encoder_2 = None
|
126 |
|
127 |
+
# ------------------------------------------------------------------------
|
128 |
+
# Encode a text prompt
|
129 |
+
# Use + 4096 => 2048 projection for standard embeds, but
|
130 |
+
# 4096 => 1280 for pooled embeds, because that's what the unet requires.
|
131 |
+
# Returns exactly four tensors in the order SDXL's __call__ expects.
|
132 |
+
# ------------------------------------------------------------------------
|
133 |
def encode_prompt(
|
134 |
self,
|
135 |
prompt,
|
|
|
148 |
where B = batch * num_images_per_prompt
|
149 |
"""
|
150 |
|
151 |
+
# --- helper to tokenize on the pipeline's device ----------------
|
152 |
def _tok(text: str):
|
153 |
tok_out = self.tokenizer(
|
154 |
text,
|
|
|
182 |
else:
|
183 |
tok_neg = pool_neg = None
|
184 |
|
|
|
|
|
|
|
|
|
|
|
185 |
return tok_pos, tok_neg, pool_pos, pool_neg
|