rahul7star committed on
Commit
b811178
·
verified ·
1 Parent(s): aa2dfbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -35
app.py CHANGED
@@ -3,40 +3,61 @@ import os
3
  import subprocess
4
  import tempfile
5
  import glob
 
6
  from huggingface_hub import snapshot_download
7
  import gradio as gr
 
 
8
 
9
- # ---------------- Step 1: Download Model ----------------
10
  repo_id = "Wan-AI/Wan2.2-TI2V-5B"
11
  print(f"Downloading/loading checkpoints for {repo_id}...")
12
  ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
13
  print(f"Using checkpoints from {ckpt_dir}")
14
 
15
- # ---------------- Step 2: Duration Calculation ----------------
16
def get_duration(prompt, size, duration_seconds, steps, progress):
    """Estimate the GPU time budget for one generation request.

    This function is passed as the ``duration`` callback of ``@spaces.GPU``,
    so its signature mirrors the decorated generation function. Only
    ``duration_seconds`` and ``steps`` influence the result.

    Args:
        prompt: Unused; accepted for signature compatibility.
        size: Unused; accepted for signature compatibility. Earlier revisions
            parsed it into a height/width pair that was never read.
        duration_seconds: Requested clip length; truncated with ``int()``.
        steps: Number of inference steps; truncated with ``int()``.
        progress: Unused; accepted for signature compatibility.

    Returns:
        ``int(duration_seconds) * int(steps) * 2.25 + 5`` as a float
        (presumably seconds -- confirm against the ``spaces`` package).

    Raises:
        TypeError, ValueError: If ``duration_seconds`` or ``steps`` cannot be
            converted with ``int()``.
    """
    return int(duration_seconds) * int(steps) * 2.25 + 5
25
-
26
- # ---------------- Helper: Find latest .mp4 in cwd ----------------
27
def find_latest_mp4():
    """Return the most recently written ``*.mp4`` in the current directory.

    Files are ranked by modification time. ``os.path.getctime`` is avoided on
    purpose: on Unix it reports the last metadata change (rename, chmod, ...),
    not when the file content was produced.

    Returns:
        The file name, relative to the current working directory, of the
        newest ``.mp4`` file, or ``None`` when there is none.
    """
    newest_name = None
    newest_mtime = None
    for name in glob.glob("*.mp4"):
        try:
            mtime = os.path.getmtime(name)
        except OSError:
            # The file disappeared between the glob and the stat call
            # (e.g. another request already moved it); ignore it.
            continue
        # Strict ">" keeps the first candidate on ties, like max() did.
        if newest_mtime is None or mtime > newest_mtime:
            newest_name, newest_mtime = name, mtime
    return newest_name
33
 
34
- # ---------------- Step 3: Generation Functions ----------------
35
-
36
-
37
- def get_duration0(
38
- prompt, size, duration_seconds, steps, progress):
39
- """Calculate dynamic GPU duration based on parameters."""
40
  if duration_seconds >= 3:
41
  return 220
42
  elif steps > 35 and duration_seconds >= 2:
@@ -46,17 +67,31 @@ def get_duration0(
46
  else:
47
  return 90
48
 
49
- # --- 2. Gradio Inference Function ---
50
-
 
 
 
 
51
 
 
52
 
53
- @spaces.GPU(duration=get_duration0)
54
  def generate_t2v(prompt, size="1280*704", duration_seconds=5, steps=25, progress=gr.Progress(track_tqdm=True)):
55
  if not prompt.strip():
56
  return None, None, "Please enter a prompt."
57
 
58
  temp_dir = tempfile.mkdtemp()
59
 
 
 
 
 
 
 
 
 
 
60
  cmd = [
61
  "python", "generate.py",
62
  "--task", "ti2v-5B",
@@ -75,6 +110,8 @@ def generate_t2v(prompt, size="1280*704", duration_seconds=5, steps=25, progress
75
  except subprocess.CalledProcessError as e:
76
  return None, None, f"Error during T2V generation: {e}"
77
 
 
 
78
  video_file = find_latest_mp4()
79
  if video_file is None:
80
  return None, None, "Generation finished but no video file found."
@@ -82,7 +119,7 @@ def generate_t2v(prompt, size="1280*704", duration_seconds=5, steps=25, progress
82
  dest_path = os.path.join(temp_dir, os.path.basename(video_file))
83
  os.rename(video_file, dest_path)
84
 
85
- download_link = f"<a href='file/{dest_path}' download>📥 Download Video</a>"
86
  return dest_path, download_link, "Text-to-Video generated successfully!"
87
 
88
  @spaces.GPU(duration=get_duration)
@@ -91,8 +128,17 @@ def generate_i2v(image, prompt, size="1280*704", duration_seconds=5, steps=25, p
91
  return None, None, "Please upload an image and enter a prompt."
92
 
93
  temp_dir = tempfile.mkdtemp()
 
 
 
 
 
 
 
 
 
94
  image_path = os.path.join(temp_dir, "input.jpg")
95
- image.save(image_path)
96
 
97
  cmd = [
98
  "python", "generate.py",
@@ -112,6 +158,8 @@ def generate_i2v(image, prompt, size="1280*704", duration_seconds=5, steps=25, p
112
  except subprocess.CalledProcessError as e:
113
  return None, None, f"Error during I2V generation: {e}"
114
 
 
 
115
  video_file = find_latest_mp4()
116
  if video_file is None:
117
  return None, None, "Generation finished but no video file found."
@@ -119,13 +167,12 @@ def generate_i2v(image, prompt, size="1280*704", duration_seconds=5, steps=25, p
119
  dest_path = os.path.join(temp_dir, os.path.basename(video_file))
120
  os.rename(video_file, dest_path)
121
 
122
- download_link = f"<a href='file/{dest_path}' download>📥 Download Video</a>"
123
  return dest_path, download_link, "Image-to-Video generated successfully!"
124
 
125
- # ---------------- Step 4: Gradio UI ----------------
126
- with gr.Blocks() as demo:
127
-
128
 
 
129
  gr.Markdown("## 🎥 Wan2.2-TI2V-5B Video Generator")
130
  gr.Markdown("Choose **Text-to-Video** or **Image-to-Video** mode below.")
131
 
@@ -134,7 +181,7 @@ with gr.Blocks() as demo:
134
  label="Prompt",
135
  value="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage"
136
  )
137
- t2v_size = gr.Textbox(label="Video Size", value="1280*704")
138
  t2v_duration = gr.Number(label="Video Length (seconds)", value=5)
139
  t2v_steps = gr.Number(label="Inference Steps", value=25)
140
  t2v_btn = gr.Button("Generate from Text")
@@ -148,7 +195,7 @@ with gr.Blocks() as demo:
148
  )
149
 
150
  with gr.Tab("Image-to-Video"):
151
- i2v_image = gr.Image(type="pil", label="Upload Image")
152
  i2v_prompt = gr.Textbox(
153
  label="Prompt",
154
  value=(
@@ -160,7 +207,7 @@ with gr.Blocks() as demo:
160
  "intricate details and the refreshing atmosphere of the seaside."
161
  )
162
  )
163
- i2v_size = gr.Textbox(label="Video Size", value="1280*704")
164
  i2v_duration = gr.Number(label="Video Length (seconds)", value=5)
165
  i2v_steps = gr.Number(label="Inference Steps", value=25)
166
  i2v_btn = gr.Button("Generate from Image")
@@ -173,5 +220,17 @@ with gr.Blocks() as demo:
173
  [i2v_video, i2v_download, i2v_status]
174
  )
175
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  if __name__ == "__main__":
177
  demo.launch()
 
3
  import subprocess
4
  import tempfile
5
  import glob
6
+ import gc
7
  from huggingface_hub import snapshot_download
8
  import gradio as gr
9
+ from PIL import Image
10
+ import numpy as np
11
 
12
+ # -------- Model Download --------
13
  repo_id = "Wan-AI/Wan2.2-TI2V-5B"
14
  print(f"Downloading/loading checkpoints for {repo_id}...")
15
  ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
16
  print(f"Using checkpoints from {ckpt_dir}")
17
 
18
# -------- Constants --------
# NOTE(review): FIXED_FPS, MIN_FRAMES_MODEL and MAX_FRAMES_MODEL are not
# referenced in the visible part of this file -- confirm they are still needed.
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121

# Height and width are floored to a multiple of this value before use.
MOD_VALUE = 32

# Fallback height/width used when a size cannot be derived or parsed.
DEFAULT_H_SLIDER_VALUE = 704
DEFAULT_W_SLIDER_VALUE = 1280

# Target pixel area when deriving a size from an uploaded image.
NEW_FORMULA_MAX_AREA = 1280.0 * 704.0

# Bounds applied to a derived height and width.
SLIDER_MIN_H = 128
SLIDER_MAX_H = 1280
SLIDER_MIN_W = 128
SLIDER_MAX_W = 1280
28
+
29
+ # -------- Helpers --------
30
def _calculate_new_dimensions(pil_image):
    """Derive a (height, width) pair that fits the given image.

    The pair keeps roughly the image's aspect ratio while targeting an area
    of ``NEW_FORMULA_MAX_AREA`` pixels. Each side is floored to a multiple of
    ``MOD_VALUE`` (never below one multiple) and then clamped to the slider
    bounds.

    Args:
        pil_image: Object whose ``.size`` attribute is ``(width, height)``.

    Returns:
        ``(new_h, new_w)`` as plain ints. The default height and width are
        returned when the image reports a non-positive side.
    """
    src_w, src_h = pil_image.size
    if min(src_w, src_h) <= 0:
        return DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE

    def _snap(side):
        # Floor to a multiple of MOD_VALUE, but keep at least one multiple.
        return max(MOD_VALUE, (side // MOD_VALUE) * MOD_VALUE)

    ratio = src_h / src_w
    snapped_h = _snap(round(np.sqrt(NEW_FORMULA_MAX_AREA * ratio)))
    snapped_w = _snap(round(np.sqrt(NEW_FORMULA_MAX_AREA / ratio)))

    # The upper bounds are themselves floored so the clamp cannot
    # produce a value that is not a multiple of MOD_VALUE.
    ceiling_h = (SLIDER_MAX_H // MOD_VALUE) * MOD_VALUE
    ceiling_w = (SLIDER_MAX_W // MOD_VALUE) * MOD_VALUE

    new_h = int(np.clip(snapped_h, SLIDER_MIN_H, ceiling_h))
    new_w = int(np.clip(snapped_w, SLIDER_MIN_W, ceiling_w))
    return new_h, new_w
46
+
47
def handle_image_upload_for_dims(uploaded_pil_image, current_h_val, current_w_val):
    """Build height/width component updates for a newly uploaded image.

    Args:
        uploaded_pil_image: The uploaded image, or ``None`` when the input
            was cleared. Objects exposing a ``shape`` attribute are converted
            through ``Image.fromarray(...).convert("RGB")`` first.
        current_h_val: Unused; accepted for the event-handler signature.
        current_w_val: Unused; accepted for the event-handler signature.

    Returns:
        A pair ``(gr.update(value=height), gr.update(value=width))``. The
        default height and width are used when no image is given or when
        anything goes wrong while computing the dimensions.
    """
    def _defaults():
        return (
            gr.update(value=DEFAULT_H_SLIDER_VALUE),
            gr.update(value=DEFAULT_W_SLIDER_VALUE),
        )

    if uploaded_pil_image is None:
        return _defaults()

    try:
        source = uploaded_pil_image
        if hasattr(source, "shape"):
            source = Image.fromarray(source).convert("RGB")
        height, width = _calculate_new_dimensions(source)
        return gr.update(value=height), gr.update(value=width)
    except Exception:
        # Best effort: any failure falls back to the default size.
        return _defaults()
 
 
 
 
 
 
 
 
 
 
59
 
60
+ def get_duration(prompt, size, duration_seconds, steps, progress):
 
 
 
 
 
61
  if duration_seconds >= 3:
62
  return 220
63
  elif steps > 35 and duration_seconds >= 2:
 
67
  else:
68
  return 90
69
 
70
def find_latest_mp4():
    """Return the most recently written ``*.mp4`` in the current directory.

    Files are ranked by modification time. ``os.path.getctime`` is avoided on
    purpose: on Unix it reports the last metadata change (rename, chmod, ...),
    not when the file content was produced.

    Returns:
        The file name, relative to the current working directory, of the
        newest ``.mp4`` file, or ``None`` when there is none.
    """
    newest_name = None
    newest_mtime = None
    for name in glob.glob("*.mp4"):
        try:
            mtime = os.path.getmtime(name)
        except OSError:
            # The file disappeared between the glob and the stat call
            # (e.g. another request already moved it); ignore it.
            continue
        # Strict ">" keeps the first candidate on ties, like max() did.
        if newest_mtime is None or mtime > newest_mtime:
            newest_name, newest_mtime = name, mtime
    return newest_name
76
 
77
+ # -------- Generation Functions --------
78
 
79
+ @spaces.GPU(duration=get_duration)
80
  def generate_t2v(prompt, size="1280*704", duration_seconds=5, steps=25, progress=gr.Progress(track_tqdm=True)):
81
  if not prompt.strip():
82
  return None, None, "Please enter a prompt."
83
 
84
  temp_dir = tempfile.mkdtemp()
85
 
86
+ # Ensure size is multiples of MOD_VALUE (h*w)
87
+ try:
88
+ h, w = size.lower().replace(" ", "").split("*")
89
+ h = max(MOD_VALUE, (int(h) // MOD_VALUE) * MOD_VALUE)
90
+ w = max(MOD_VALUE, (int(w) // MOD_VALUE) * MOD_VALUE)
91
+ size = f"{h}*{w}"
92
+ except Exception:
93
+ size = f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}"
94
+
95
  cmd = [
96
  "python", "generate.py",
97
  "--task", "ti2v-5B",
 
110
  except subprocess.CalledProcessError as e:
111
  return None, None, f"Error during T2V generation: {e}"
112
 
113
+ gc.collect()
114
+
115
  video_file = find_latest_mp4()
116
  if video_file is None:
117
  return None, None, "Generation finished but no video file found."
 
119
  dest_path = os.path.join(temp_dir, os.path.basename(video_file))
120
  os.rename(video_file, dest_path)
121
 
122
+ download_link = f"<a href='{os.path.basename(dest_path)}' download>📥 Download Video</a>"
123
  return dest_path, download_link, "Text-to-Video generated successfully!"
124
 
125
  @spaces.GPU(duration=get_duration)
 
128
  return None, None, "Please upload an image and enter a prompt."
129
 
130
  temp_dir = tempfile.mkdtemp()
131
+
132
+ try:
133
+ h, w = size.lower().replace(" ", "").split("*")
134
+ h = max(MOD_VALUE, (int(h) // MOD_VALUE) * MOD_VALUE)
135
+ w = max(MOD_VALUE, (int(w) // MOD_VALUE) * MOD_VALUE)
136
+ size = f"{h}*{w}"
137
+ except Exception:
138
+ size = f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}"
139
+
140
  image_path = os.path.join(temp_dir, "input.jpg")
141
+ Image.fromarray(image).save(image_path)
142
 
143
  cmd = [
144
  "python", "generate.py",
 
158
  except subprocess.CalledProcessError as e:
159
  return None, None, f"Error during I2V generation: {e}"
160
 
161
+ gc.collect()
162
+
163
  video_file = find_latest_mp4()
164
  if video_file is None:
165
  return None, None, "Generation finished but no video file found."
 
167
  dest_path = os.path.join(temp_dir, os.path.basename(video_file))
168
  os.rename(video_file, dest_path)
169
 
170
+ download_link = f"<a href='{os.path.basename(dest_path)}' download>📥 Download Video</a>"
171
  return dest_path, download_link, "Image-to-Video generated successfully!"
172
 
173
+ # -------- Gradio UI --------
 
 
174
 
175
+ with gr.Blocks() as demo:
176
  gr.Markdown("## 🎥 Wan2.2-TI2V-5B Video Generator")
177
  gr.Markdown("Choose **Text-to-Video** or **Image-to-Video** mode below.")
178
 
 
181
  label="Prompt",
182
  value="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage"
183
  )
184
+ t2v_size = gr.Textbox(label="Video Size (HxW)", value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
185
  t2v_duration = gr.Number(label="Video Length (seconds)", value=5)
186
  t2v_steps = gr.Number(label="Inference Steps", value=25)
187
  t2v_btn = gr.Button("Generate from Text")
 
195
  )
196
 
197
  with gr.Tab("Image-to-Video"):
198
+ i2v_image = gr.Image(type="numpy", label="Upload Image")
199
  i2v_prompt = gr.Textbox(
200
  label="Prompt",
201
  value=(
 
207
  "intricate details and the refreshing atmosphere of the seaside."
208
  )
209
  )
210
+ i2v_size = gr.Textbox(label="Video Size (HxW)", value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
211
  i2v_duration = gr.Number(label="Video Length (seconds)", value=5)
212
  i2v_steps = gr.Number(label="Inference Steps", value=25)
213
  i2v_btn = gr.Button("Generate from Image")
 
220
  [i2v_video, i2v_download, i2v_status]
221
  )
222
 
223
# Auto adjust size on image upload for i2v.
# NOTE: these statements belong inside the `with gr.Blocks() as demo:` block;
# indent them to match the surrounding UI code when applying this change.
#
# The size field is a single textbox holding "H*W". Listing it twice as an
# output and sending two integer updates left only the width in the box
# (e.g. "1280"), which generate_i2v cannot parse. The handler below returns
# one "H*W" string for one output instead.
def _i2v_size_text_for_image(uploaded_image):
    """Return the "H*W" size text matching the uploaded image.

    Falls back to the default size when the image was cleared or when the
    dimensions cannot be computed.
    """
    default_text = f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}"
    if uploaded_image is None:
        return default_text
    try:
        if hasattr(uploaded_image, "shape"):
            pil_image = Image.fromarray(uploaded_image).convert("RGB")
        else:
            pil_image = uploaded_image
        new_h, new_w = _calculate_new_dimensions(pil_image)
    except Exception:
        # Best effort, same policy as handle_image_upload_for_dims.
        return default_text
    return f"{new_h}*{new_w}"


i2v_image.upload(
    fn=_i2v_size_text_for_image,
    inputs=[i2v_image],
    outputs=[i2v_size],
)
i2v_image.clear(
    fn=_i2v_size_text_for_image,
    inputs=[i2v_image],
    outputs=[i2v_size],
)
234
+
235
  if __name__ == "__main__":
236
  demo.launch()