remove trigger check, as it messes up the time estimate
- app.py +14 -23
- requirements.txt +1 -1
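
In short, the commit drops the live "will these parameters trigger a compilation?" feedback: the check_if_compiled wiring on the input widgets is commented out, the extra check step chained in front of generate is removed, and a static patience notice takes its place, since that trigger check messed up the queue's time estimate. A minimal sketch of the removed versus the kept wiring is below; the component names, slider ranges and toy functions are made up for illustration and are not the Space's actual code.

import gradio as gr

def check_if_compiled(steps, height, width):
    # toy stand-in for the Space's check against its _seen_compilations cache
    return 'Current parameters will trigger compilation.'

def generate(steps, height, width):
    # toy stand-in for the Space's video generator
    return f'generated {int(width)}x{int(height)} in {int(steps)} steps'

with gr.Blocks() as demo:
    steps = gr.Slider(1, 50, value = 20, label = 'Inference steps')
    height = gr.Slider(256, 768, value = 512, step = 64, label = 'Height')
    width = gr.Slider(256, 768, value = 512, step = 64, label = 'Width')
    out = gr.Textbox(label = 'Output')

    # Kept after this commit: one static notice, one event per click.
    patience = gr.Markdown('**Please be patient. The model might have to compile with current parameters.**')
    submit = gr.Button('Generate')
    submit.click(fn = generate, inputs = [steps, height, width], outputs = out)

    # Removed by this commit: a warning recomputed on every parameter change,
    # plus a check step chained ahead of generate via .then(), which is what
    # the commit message says messed up the time estimate.
    # will_trigger = gr.Markdown('')
    # for comp in (steps, height, width):
    #     comp.change(fn = check_if_compiled, inputs = [steps, height, width], outputs = will_trigger)
    # submit.click(fn = check_if_compiled, inputs = [steps, height, width], outputs = patience) \
    #       .then(fn = generate, inputs = [steps, height, width], outputs = out)

demo.launch()

The effect is visible in the diff that follows: submit_button.click now maps directly onto generate instead of going through check_if_compiled first.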
app.py
CHANGED
@@ -89,6 +89,8 @@ def generate(
 def check_if_compiled(image, inference_steps, height, width, num_frames, message):
     height = int(height)
     width = int(width)
+    height = (height // 64) * 64
+    width = (width // 64) * 64
     hint_image = image
     if (hint_image is None, inference_steps, height, width, num_frames) in _seen_compilations:
         return ''
@@ -100,7 +102,7 @@ if _preheat:
     generate(
         prompt = 'preheating the oven',
         neg_prompt = '',
-        image =
+        image = None,
         inference_steps = 20,
         cfg = 12.0,
         seed = 0
@@ -109,7 +111,7 @@ if _preheat:
     dada = generate(
         prompt = 'Entertaining the guests with sailor songs played on an old harmonium.',
         neg_prompt = '',
-        image =
+        image = Image.new('RGB', size = (512, 512), color = (0, 0, 0)),
         inference_steps = 20,
         cfg = 12.0,
         seed = 0
@@ -227,29 +229,22 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
             )
         with gr.Column(variant = variant):
             #no_gpu = gr.Markdown('**Until a GPU is assigned expect extremely long runtimes up to 1h+**')
-            will_trigger = gr.Markdown('')
-            patience = gr.Markdown('')
+            #will_trigger = gr.Markdown('')
+            patience = gr.Markdown('**Please be patient. The model might have to compile with current parameters.**')
             image_output = gr.Image(
                 label = 'Output',
                 value = 'example.webp',
                 interactive = False
             )
-            trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input ]
-            trigger_check_fun = partial(check_if_compiled, message = 'Current parameters will trigger compilation.')
-            height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
-            width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
-            num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
-            image_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
-            inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
-            will_trigger.value = trigger_check_fun(image_input.value, inference_steps_input.value, height_input.value, width_input.value, num_frames_input.value)
+            #trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input ]
+            #trigger_check_fun = partial(check_if_compiled, message = 'Current parameters will trigger compilation.')
+            #height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+            #width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+            #num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+            #image_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+            #inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+            #will_trigger.value = trigger_check_fun(image_input.value, inference_steps_input.value, height_input.value, width_input.value, num_frames_input.value)
             ev = submit_button.click(
-                fn = partial(
-                    check_if_compiled,
-                    message = 'Please be patient. The model has to be compiled with current parameters.'
-                ),
-                inputs = trigger_inputs,
-                outputs = patience
-            ).then(
                 fn = generate,
                 inputs = [
                     prompt_input,
@@ -265,10 +260,6 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
                 ],
                 outputs = image_output,
                 postprocess = False
-            ).then(
-                fn = trigger_check_fun,
-                inputs = trigger_inputs,
-                outputs = will_trigger
             )
             #cancel_button.click(fn = lambda: None, cancels = ev)

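Within app.py, the first hunk snaps height and width down to multiples of 64 before they enter the compilation cache key, so the key reflects the resolution the model is actually compiled for. A standalone sketch of that keying idea follows; only the snapping comes from the diff, while the set semantics of _seen_compilations (filled by the generator after each JIT compilation) and the example values are assumptions for illustration.

# Sketch of the shape-cache idea behind check_if_compiled; only the snapping
# to multiples of 64 is taken from the diff, the rest is assumed.
_seen_compilations = set()

def compile_key(image, inference_steps, height, width, num_frames):
    # Snap to multiples of 64 so the key matches the shape the model compiles for.
    height = (int(height) // 64) * 64
    width = (int(width) // 64) * 64
    return (image is None, inference_steps, height, width, num_frames)

def will_compile(image, inference_steps, height, width, num_frames):
    return compile_key(image, inference_steps, height, width, num_frames) not in _seen_compilations

# 520 snaps down to 512, so both calls below share one key and only the first
# combination would actually trigger a fresh compilation.
print(will_compile(None, 20, 520, 512, 24))                  # True
_seen_compilations.add(compile_key(None, 20, 520, 512, 24))  # generate() would record this
print(will_compile(None, 20, 512, 512, 24))                  # False, same snapped key
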
requirements.txt
CHANGED
@@ -6,5 +6,5 @@ einops
 -f https://download.pytorch.org/whl/cpu/torch
 torch[cpu]
 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-jax[cuda11_cudnn82] #jax[cuda11_cudnn86] #jax[cuda11_cudnn805]
+jax[cuda11_pip] #jax[cuda11_cudnn82] #jax[cuda11_cudnn86] #jax[cuda11_cudnn805]
 flax
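
The requirements change switches the JAX extra from cuda11_cudnn82 to cuda11_pip, which is meant to pull the CUDA/cuDNN libraries as pip wheels from the jax_cuda_releases index instead of depending on a matching system cuDNN. An illustrative (not part of the commit) sanity check that the CUDA backend is actually picked up after install:

# Illustrative check only: confirm JAX sees the GPU after installing
# jax[cuda11_pip]; on a CPU-only machine this reports 'cpu'.
import jax

print(jax.__version__)
print(jax.default_backend())  # 'gpu' when the CUDA wheels are found
print(jax.devices())          # GPU device entries when CUDA is available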