Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -150,37 +150,46 @@ def call(
|
|
150 |
|
151 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
152 |
extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)
|
153 |
-
|
154 |
# 7. Prepare added time ids & embeddings
|
155 |
add_text_embeds = pooled_prompt_embeds
|
156 |
add_text2_embeds = pooled_prompt2_embeds
|
157 |
-
|
158 |
-
# Default dtype if prompt_embeds or prompt2_embeds are None
|
159 |
-
default_dtype = torch.float32
|
160 |
-
|
161 |
-
# Check and set dtype for add_time_ids
|
162 |
-
dtype_for_add_time_ids = prompt_embeds.dtype if prompt_embeds is not None else default_dtype
|
163 |
add_time_ids = pipe._get_add_time_ids(
|
164 |
-
original_size, crops_coords_top_left, target_size, dtype=
|
165 |
)
|
166 |
-
|
167 |
-
# Check and set dtype for add_time2_ids
|
168 |
-
dtype_for_add_time2_ids = prompt2_embeds.dtype if prompt2_embeds is not None else default_dtype
|
169 |
add_time2_ids = pipe._get_add_time_ids(
|
170 |
-
original_size, crops_coords_top_left, target_size, dtype=
|
171 |
)
|
172 |
-
|
173 |
if negative_original_size is not None and negative_target_size is not None:
|
174 |
negative_add_time_ids = pipe._get_add_time_ids(
|
175 |
negative_original_size,
|
176 |
negative_crops_coords_top_left,
|
177 |
negative_target_size,
|
178 |
-
dtype=
|
179 |
)
|
180 |
else:
|
181 |
negative_add_time_ids = add_time_ids
|
182 |
negative_add_time2_ids = add_time2_ids
|
183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
# 8. Denoising loop
|
185 |
num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)
|
186 |
|
|
|
150 |
|
151 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
152 |
extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)
|
153 |
+
|
154 |
# 7. Prepare added time ids & embeddings
|
155 |
add_text_embeds = pooled_prompt_embeds
|
156 |
add_text2_embeds = pooled_prompt2_embeds
|
157 |
+
|
|
|
|
|
|
|
|
|
|
|
158 |
add_time_ids = pipe._get_add_time_ids(
|
159 |
+
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
160 |
)
|
|
|
|
|
|
|
161 |
add_time2_ids = pipe._get_add_time_ids(
|
162 |
+
original_size, crops_coords_top_left, target_size, dtype=prompt2_embeds.dtype
|
163 |
)
|
164 |
+
|
165 |
if negative_original_size is not None and negative_target_size is not None:
|
166 |
negative_add_time_ids = pipe._get_add_time_ids(
|
167 |
negative_original_size,
|
168 |
negative_crops_coords_top_left,
|
169 |
negative_target_size,
|
170 |
+
dtype=prompt_embeds.dtype,
|
171 |
)
|
172 |
else:
|
173 |
negative_add_time_ids = add_time_ids
|
174 |
negative_add_time2_ids = add_time2_ids
|
175 |
|
176 |
+
if do_classifier_free_guidance:
|
177 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
178 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
179 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
180 |
+
|
181 |
+
prompt2_embeds = torch.cat([negative_prompt2_embeds, prompt2_embeds], dim=0)
|
182 |
+
add_text2_embeds = torch.cat([negative_pooled_prompt2_embeds, add_text2_embeds], dim=0)
|
183 |
+
add_time2_ids = torch.cat([negative_add_time2_ids, add_time2_ids], dim=0)
|
184 |
+
|
185 |
+
prompt_embeds = prompt_embeds.to(device)
|
186 |
+
add_text_embeds = add_text_embeds.to(device)
|
187 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
188 |
+
|
189 |
+
prompt2_embeds = prompt2_embeds.to(device)
|
190 |
+
add_text2_embeds = add_text2_embeds.to(device)
|
191 |
+
add_time2_ids = add_time2_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
192 |
+
|
193 |
# 8. Denoising loop
|
194 |
num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)
|
195 |
|