rahul7star committed
Commit 8dfdf4d · verified · 1 Parent(s): 9ca902e

Update app.py

Files changed (1)
  1. app.py  +122 -60
app.py CHANGED
@@ -35,22 +35,7 @@ hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/
 hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
 print("Downloads complete.")
 
-
-
-
-LANDSCAPE_WIDTH = 832
-LANDSCAPE_HEIGHT = 480
-MAX_SEED = np.iinfo(np.int32).max
-
-FIXED_FPS = 16
-MIN_FRAMES_MODEL = 8
-MAX_FRAMES_MODEL = 81
-
-MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS,1)
-MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS,1)
-
-
-
+model_management.vram_state = model_management.VRAMState.HIGH_VRAM
 
 # --- Image Processing Functions ---
 def calculate_video_dimensions(width, height, max_size=832, min_size=480):
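A note on the constants removed above: at FIXED_FPS = 16, the 8 to 81 frame window corresponds to roughly 0.5 to 5.1 seconds of video, which is where the old UI's 5-second cap came from; the rewritten app hardcodes frame counts instead. A minimal sketch of that arithmetic, where seconds_to_frames is a hypothetical helper and not part of the commit:

```python
# Frame/second bookkeeping implied by the removed constants (16 fps, 8-81 frames).
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8    # 8 / 16  = 0.5 s
MAX_FRAMES_MODEL = 81   # 81 / 16 ≈ 5.1 s

def seconds_to_frames(seconds: float) -> int:
    """Hypothetical helper: clamp a requested duration (seconds) to the model's frame window."""
    frames = int(round(seconds * FIXED_FPS))
    return max(MIN_FRAMES_MODEL, min(frames, MAX_FRAMES_MODEL))

# The hardcoded options added later in this diff map back as:
#   33 frames / 16 fps ≈ 2.1 s   ("Short (2s)")
#   66 frames / 16 fps ≈ 4.1 s   ("Mid (4s)")
print(seconds_to_frames(2), seconds_to_frames(10))  # -> 32 81
```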
@@ -283,65 +268,142 @@ model_management.load_models_gpu([
     loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
 ])
 print("All models loaded successfully!")
-import time
-import gradio as gr
-import tempfile
-import torch
-import random
-import spaces
-
-# --- Dynamic GPU duration logic ---
-def get_duration(
-    start_image_pil,
-    end_image_pil,
-    prompt,
-    negative_prompt,
-    duration_seconds,
-    progress,
-):
-    # 15ms per step → just an example
-    calc_time = steps * 15
-    print(f"[GPU Duration Estimate] {calc_time} sec for {steps} steps")
-    return min(calc_time, 300)  # hard cap for safety
-
 
 # --- Main Video Generation Logic ---
-@spaces.GPU(duration=get_duration)
+@spaces.GPU(duration=120)
 def generate_video(
     start_image_pil,
     end_image_pil,
     prompt,
     negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
-    duration_seconds=duration_seconds,
+    duration=33,
     progress=gr.Progress(track_tqdm=True)
 ):
     """
     The main function to generate a video based on user inputs.
     This function is called every time the user clicks the 'Generate' button.
     """
-    start_time = time.time()
     FPS = 16
-    duration = int(FPS * duration_seconds)  # convert seconds → frames
+
+    # Process images: resize and crop second image to match first
+    # The first image determines the dimensions
+    processed_start_image = start_image_pil.copy()
+    processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)
+
+    # Calculate video dimensions based on the first image
+    video_width, video_height = calculate_video_dimensions(
+        processed_start_image.width,
+        processed_start_image.height
+    )
+
+    print(f"Input image size: {processed_start_image.width}x{processed_start_image.height}")
+    print(f"Video dimensions: {video_width}x{video_height}")
+
+    clip = MODELS_AND_NODES["clip"]
+    vae = MODELS_AND_NODES["vae"]
+    model_low_noise = MODELS_AND_NODES["model_low_noise"]
+    model_high_noise = MODELS_AND_NODES["model_high_noise"]
+    clip_vision = MODELS_AND_NODES["clip_vision"]
+
+    cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
+    loadimage = MODELS_AND_NODES["LoadImage"]
+    clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
+    modelsamplingsd3 = MODELS_AND_NODES["ModelSamplingSD3"]
+    pathchsageattentionkj = MODELS_AND_NODES["PathchSageAttentionKJ"]
+    wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
+    ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
+    vaedecode = MODELS_AND_NODES["VAEDecode"]
+    createvideo = MODELS_AND_NODES["CreateVideo"]
+    savevideo = MODELS_AND_NODES["SaveVideo"]
+
+    # Save processed images to temporary files
+    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as start_file, \
+         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as end_file:
+        processed_start_image.save(start_file.name)
+        processed_end_image.save(end_file.name)
+        start_image_path = start_file.name
+        end_image_path = end_file.name
+
+    with torch.inference_mode():
+        progress(0.1, desc="Encoding text and images...")
+        # --- Workflow execution ---
+        positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
+        negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))
+
+        start_image_loaded = loadimage.load_image(image=start_image_path)
+        end_image_loaded = loadimage.load_image(image=end_image_path)
+
+        clip_vision_encoded_start = clipvisionencode.encode(
+            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0)
+        )
+        clip_vision_encoded_end = clipvisionencode.encode(
+            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0)
+        )
 
-    # --- Your existing video gen code continues here ---
-    # (I trimmed it for brevity, leave all nodes/patches/workflow unchanged)
+        progress(0.2, desc="Preparing initial latents...")
+        initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
+            width=video_width, height=video_height, length=duration, batch_size=1,
+            positive=get_value_at_index(positive_conditioning, 0),
+            negative=get_value_at_index(negative_conditioning, 0),
+            vae=get_value_at_index(vae, 0),
+            clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
+            clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
+            start_image=get_value_at_index(start_image_loaded, 0),
+            end_image=get_value_at_index(end_image_loaded, 0),
+        )
 
-    # final save video logic...
-    elapsed = time.time() - start_time
-    print(f"[GPU Time Log] Video generated in {elapsed:.2f} sec")
+        progress(0.3, desc="Patching models...")
+        model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_noise, 0))
+        model_low_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))
+
+        model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_noise, 0))
+        model_high_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
+
+        progress(0.5, desc="Running KSampler (Step 1/2)...")
+        latent_step1 = ksampleradvanced.sample(
+            add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
+            sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
+            return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
+            positive=get_value_at_index(initial_latents, 0),
+            negative=get_value_at_index(initial_latents, 1),
+            latent_image=get_value_at_index(initial_latents, 2),
+        )
+
+        progress(0.7, desc="Running KSampler (Step 2/2)...")
+        latent_step2 = ksampleradvanced.sample(
+            add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
+            sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
+            return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
+            positive=get_value_at_index(initial_latents, 0),
+            negative=get_value_at_index(initial_latents, 1),
+            latent_image=get_value_at_index(latent_step1, 0),
+        )
+
+        progress(0.8, desc="Decoding VAE...")
+        decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))
+
+        progress(0.9, desc="Creating and saving video...")
+        video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))
+
+        # Save the video to ComfyUI's output directory
+        save_result = savevideo.save_video(
+            filename_prefix="GradioVideo", format="mp4", codec="h264",
+            video=get_value_at_index(video_data, 0),
+        )
+
+        progress(1.0, desc="Done!")
+        return f"output/{save_result['ui']['images'][0]['filename']}"
 
-    return f"output/{save_result['ui']['images'][0]['filename']}"
 
 
-# --- Gradio UI ---
 css = '''
 .fillable{max-width: 1100px !important}
 .dark .progress-text {color: white}
 '''
 with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
     gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
-    gr.Markdown("GPU time is dynamically calculated. Max video duration: **5 seconds**.")
-
+    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) and the [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA on ZeroGPU")
+
     with gr.Row():
         with gr.Column():
             with gr.Group():
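Review note on the GPU-budget change in the hunk above: the removed get_duration callback reads a steps variable that never appears in the diff, so the dynamic estimate could not have worked as written, and the commit falls back to a fixed @spaces.GPU(duration=120). The ZeroGPU decorator does accept a callable here, which is the pattern the removed code was attempting; a minimal sketch of a working version, assuming the new generate_video signature and purely illustrative timing numbers:

```python
import spaces

def estimate_gpu_seconds(start_image_pil, end_image_pil, prompt,
                         negative_prompt="", duration=33,
                         progress=None):
    # ZeroGPU calls this with the same arguments as the wrapped function and
    # expects a number of seconds back. The overhead and per-frame cost below
    # are placeholder guesses, not measurements from this Space.
    overhead_s = 40
    per_frame_s = 1.2
    return min(int(overhead_s + per_frame_s * duration), 300)  # cap, as the removed code did

@spaces.GPU(duration=estimate_gpu_seconds)
def generate_video(start_image_pil, end_image_pil, prompt,
                   negative_prompt="", duration=33,
                   progress=None):
    ...
```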
@@ -350,14 +412,14 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
                 end_image = gr.Image(type="pil", label="End Frame")
 
             prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
-
-            # Duration bar (1–5 seconds)
-            duration_seconds = gr.Slider(
-                minimum=1, maximum=5, value=2, step=1,
-                label="Video Duration (seconds)"
-            )
-
-            with gr.Accordion("Advanced Settings", open=False, visible=False):
+
+            with gr.Accordion("Advanced Settings", open=False, visible=True):
+                duration = gr.Radio(
+                    [("Short (2s)", 33), ("Mid (4s)", 66)],
+                    value=33,
+                    label="Video Duration",
+                    visible=False
+                )
                 negative_prompt = gr.Textbox(
                     label="Negative Prompt",
                     value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
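One UI detail worth calling out in the hunk above: the new duration control holds a frame count rather than seconds (33 frames ≈ 2.1 s and 66 ≈ 4.1 s at the function's 16 fps), and it is created with visible=False, so the accordion only ever shows the negative prompt and every run uses the 33-frame default. A small sketch of the same component with the units made explicit; the visible=True flip and the relabelled choices are suggestions, not part of the commit:

```python
import gradio as gr

# Each choice is a (label, value) pair; the value (a frame count, not seconds)
# is what Gradio hands to generate_video's `duration` parameter.
duration = gr.Radio(
    choices=[("Short (~2 s, 33 frames)", 33), ("Mid (~4 s, 66 frames)", 66)],
    value=33,
    label="Video Duration",
    visible=True,  # the committed version hides the control with visible=False
)
```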
@@ -371,7 +433,7 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
 
     generate_button.click(
         fn=generate_video,
-        inputs=[start_image, end_image, prompt, negative_prompt, duration_seconds],
+        inputs=[start_image, end_image, prompt, negative_prompt, duration],
         outputs=output_video
     )
 
@@ -388,4 +450,4 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
     )
 
 if __name__ == "__main__":
-    app.launch(share=True)
+    app.launch(share=True)
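Taken together, the body added to generate_video in the second hunk follows the two-expert schedule Wan 2.2 A14B uses: the high-noise model denoises steps 0 to 4 and hands its latent, with leftover noise kept, to the low-noise model, which finishes the remaining steps, both at cfg=1 over 8 total steps with the Lightning LoRAs. A condensed sketch of that handoff is below; sample_two_stage and the sampler parameter are illustrative stand-ins for the ComfyUI KSamplerAdvanced wrapper the app actually calls:

```python
# Two-expert step split mirrored from the diff's KSamplerAdvanced calls.
TOTAL_STEPS = 8
SWITCH_STEP = 4
CFG = 1

def sample_two_stage(sampler, model_high, model_low, latents, seed):
    # Stage 1: steps [0, SWITCH_STEP) on the high-noise expert, keeping leftover noise.
    noisy = sampler.sample(
        model=model_high, latent_image=latents, noise_seed=seed,
        steps=TOTAL_STEPS, cfg=CFG, start_at_step=0, end_at_step=SWITCH_STEP,
        add_noise="enable", return_with_leftover_noise="enable",
    )
    # Stage 2: steps [SWITCH_STEP, end) on the low-noise expert, no fresh noise added.
    return sampler.sample(
        model=model_low, latent_image=noisy, noise_seed=seed,
        steps=TOTAL_STEPS, cfg=CFG, start_at_step=SWITCH_STEP, end_at_step=10000,
        add_noise="disable", return_with_leftover_noise="disable",
    )
```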
 