Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,48 +1,357 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
import os
|
4 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import torch
|
6 |
-
from
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
|
|
|
16 |
|
17 |
-
|
|
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
|
|
|
|
28 |
try:
|
29 |
-
# Load
|
30 |
-
|
|
|
|
|
|
|
31 |
|
32 |
-
#
|
33 |
-
|
34 |
|
35 |
-
#
|
|
|
|
|
|
|
36 |
with torch.no_grad():
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
|
|
|
40 |
|
41 |
-
|
42 |
-
|
|
|
43 |
|
|
|
|
|
44 |
except Exception as e:
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import ffmpeg
|
3 |
import os
|
4 |
+
import uuid
|
5 |
+
import requests
|
6 |
+
import tempfile
|
7 |
+
import shutil
|
8 |
+
import re
|
9 |
+
import time
|
10 |
+
import concurrent.futures
|
11 |
import torch
|
12 |
+
from pathlib import Path
|
13 |
+
from dotenv import load_dotenv
|
14 |
+
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
|
15 |
+
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
16 |
+
from PIL import Image
|
17 |
|
18 |
+
# Add GPU decorator for Hugging Face Spaces
|
19 |
+
try:
|
20 |
+
from spaces import GPU
|
21 |
+
use_gpu = True
|
22 |
+
@GPU
|
23 |
+
def get_gpu():
|
24 |
+
return True
|
25 |
+
# Call the function to trigger GPU allocation
|
26 |
+
get_gpu()
|
27 |
+
except ImportError:
|
28 |
+
use_gpu = False
|
29 |
+
print("Running without GPU acceleration")
|
30 |
|
31 |
+
# Load environment variables from .env file if it exists
|
32 |
+
load_dotenv()
|
33 |
|
34 |
+
# Global variables to hold models (lazy loading)
|
35 |
+
llava_model = None
|
36 |
+
llava_processor = None
|
37 |
+
stable_diffusion_pipeline = None
|
38 |
|
39 |
+
def load_llava_model():
|
40 |
+
"""Load LLaVA model for image captioning"""
|
41 |
+
global llava_model, llava_processor
|
42 |
+
|
43 |
+
if llava_model is None or llava_processor is None:
|
44 |
+
print("Loading LLaVA model for image analysis...")
|
45 |
+
model_id = "llava-hf/llava-1.5-7b-hf"
|
46 |
+
|
47 |
+
# Load processor and model (with reduced precision for memory efficiency)
|
48 |
+
llava_processor = AutoProcessor.from_pretrained(model_id)
|
49 |
+
llava_model = AutoModelForCausalLM.from_pretrained(
|
50 |
+
model_id,
|
51 |
+
torch_dtype=torch.float16,
|
52 |
+
device_map="auto"
|
53 |
+
)
|
54 |
+
|
55 |
+
return llava_model, llava_processor
|
56 |
|
57 |
+
def load_stable_diffusion_model():
|
58 |
+
"""Load Stable Diffusion model for Ghibli-style image generation"""
|
59 |
+
global stable_diffusion_pipeline
|
60 |
+
|
61 |
+
if stable_diffusion_pipeline is None:
|
62 |
+
print("Loading Stable Diffusion model for image generation...")
|
63 |
+
model_id = "nitrosocke/Ghibli-Diffusion"
|
64 |
+
|
65 |
+
# Load the pipeline with precision to balance memory usage and quality
|
66 |
+
stable_diffusion_pipeline = StableDiffusionPipeline.from_pretrained(
|
67 |
+
model_id,
|
68 |
+
torch_dtype=torch.float16,
|
69 |
+
safety_checker=None # Disable safety checker for performance
|
70 |
+
)
|
71 |
+
|
72 |
+
# Move to GPU if available
|
73 |
+
if torch.cuda.is_available():
|
74 |
+
stable_diffusion_pipeline = stable_diffusion_pipeline.to("cuda")
|
75 |
+
|
76 |
+
# Use the DPM-Solver++ scheduler for better quality at lower steps
|
77 |
+
stable_diffusion_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
78 |
+
stable_diffusion_pipeline.scheduler.config,
|
79 |
+
algorithm_type="dpmsolver++",
|
80 |
+
use_karras_sigmas=True
|
81 |
+
)
|
82 |
+
|
83 |
+
return stable_diffusion_pipeline
|
84 |
|
85 |
+
def analyze_image_with_llava(image_path):
|
86 |
+
"""Process a single frame with LLaVA to generate a description"""
|
87 |
try:
|
88 |
+
# Load the model if not already loaded
|
89 |
+
model, processor = load_llava_model()
|
90 |
+
|
91 |
+
# Load the image
|
92 |
+
image = Image.open(image_path)
|
93 |
|
94 |
+
# Prompt for Ghibli-specific description
|
95 |
+
prompt = "Describe this image in detail, focusing on elements that would be important to recreate it in Studio Ghibli animation style."
|
96 |
|
97 |
+
# Process the image and generate text
|
98 |
+
inputs = processor(prompt, image, return_tensors="pt").to(model.device)
|
99 |
+
|
100 |
+
# Generate with appropriate parameters
|
101 |
with torch.no_grad():
|
102 |
+
output = model.generate(
|
103 |
+
**inputs,
|
104 |
+
max_new_tokens=300,
|
105 |
+
do_sample=True,
|
106 |
+
temperature=0.7,
|
107 |
+
top_p=0.9,
|
108 |
+
)
|
109 |
|
110 |
+
# Decode the output
|
111 |
+
generated_text = processor.decode(output[0], skip_special_tokens=True)
|
112 |
|
113 |
+
# Extract just the response part (remove the prompt)
|
114 |
+
response = generated_text.split(prompt)[-1].strip()
|
115 |
+
print(f"LLaVA analysis for frame {os.path.basename(image_path)}: {response[:150]}...")
|
116 |
|
117 |
+
return response
|
118 |
+
|
119 |
except Exception as e:
|
120 |
+
import traceback
|
121 |
+
print(f"Error analyzing image {os.path.basename(image_path)}: {str(e)}")
|
122 |
+
print(traceback.format_exc())
|
123 |
+
return f"Error analyzing image: {str(e)}"
|
124 |
+
|
125 |
+
def generate_ghibli_image(image_description, style_prompt, output_path):
|
126 |
+
"""Generate a Ghibli-style image based on the description using Stable Diffusion"""
|
127 |
+
try:
|
128 |
+
# Load the model if not already loaded
|
129 |
+
pipeline = load_stable_diffusion_model()
|
130 |
+
|
131 |
+
# Combine the image description with the style prompt
|
132 |
+
full_prompt = f"{image_description}. {style_prompt}. Hand-drawn animation style, soft colors, attention to detail, Miyazaki aesthetic."
|
133 |
+
|
134 |
+
# Ensure prompt isn't too long
|
135 |
+
if len(full_prompt) > 500:
|
136 |
+
full_prompt = full_prompt[:497] + "..."
|
137 |
+
|
138 |
+
# Generate the image
|
139 |
+
with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
|
140 |
+
image = pipeline(
|
141 |
+
prompt=full_prompt,
|
142 |
+
negative_prompt="3d, cgi, low quality, blurry, distorted, deformed",
|
143 |
+
num_inference_steps=30,
|
144 |
+
guidance_scale=7.5,
|
145 |
+
width=768,
|
146 |
+
height=768,
|
147 |
+
).images[0]
|
148 |
+
|
149 |
+
# Save the generated image
|
150 |
+
image.save(output_path)
|
151 |
+
print(f"Successfully saved stylized frame: {os.path.basename(output_path)}")
|
152 |
+
return True
|
153 |
|
154 |
+
except Exception as e:
|
155 |
+
import traceback
|
156 |
+
print(f"Error generating image: {str(e)}")
|
157 |
+
print(traceback.format_exc())
|
158 |
+
return False
|
159 |
+
|
160 |
+
def process_frame(frame_path, style_prompt):
|
161 |
+
"""Process a single frame with LLaVA analysis and Stable Diffusion generation"""
|
162 |
+
try:
|
163 |
+
# First use LLaVA to analyze the image
|
164 |
+
image_description = analyze_image_with_llava(frame_path)
|
165 |
+
|
166 |
+
if image_description.startswith("Error"):
|
167 |
+
return False
|
168 |
+
|
169 |
+
# Now use Stable Diffusion to generate a stylized version
|
170 |
+
result = generate_ghibli_image(image_description, style_prompt, frame_path)
|
171 |
+
|
172 |
+
return result
|
173 |
+
|
174 |
+
except Exception as e:
|
175 |
+
import traceback
|
176 |
+
print(f"Error processing frame {os.path.basename(frame_path)}: {str(e)}")
|
177 |
+
print(traceback.format_exc())
|
178 |
+
return False
|
179 |
+
|
180 |
+
def stylize_video(video_path, style_prompt, num_frames=15):
|
181 |
+
try:
|
182 |
+
# Create temp directories
|
183 |
+
temp_dir = tempfile.mkdtemp()
|
184 |
+
input_filename = os.path.join(temp_dir, "input.mp4")
|
185 |
+
frames_dir = os.path.join(temp_dir, "frames")
|
186 |
+
os.makedirs(frames_dir, exist_ok=True)
|
187 |
+
|
188 |
+
# Save the input video to a temporary file
|
189 |
+
if isinstance(video_path, str):
|
190 |
+
if video_path.startswith('http'):
|
191 |
+
# It's a URL, download it
|
192 |
+
response = requests.get(video_path, stream=True)
|
193 |
+
with open(input_filename, 'wb') as f:
|
194 |
+
for chunk in response.iter_content(chunk_size=8192):
|
195 |
+
f.write(chunk)
|
196 |
+
elif os.path.exists(video_path):
|
197 |
+
# It's a file path, copy it
|
198 |
+
shutil.copy(video_path, input_filename)
|
199 |
+
else:
|
200 |
+
return None, f"Video file not found: {video_path}"
|
201 |
+
else:
|
202 |
+
# Assume it's binary data
|
203 |
+
with open(input_filename, "wb") as f:
|
204 |
+
f.write(video_path)
|
205 |
+
|
206 |
+
# Make sure the video file exists
|
207 |
+
if not os.path.exists(input_filename):
|
208 |
+
return None, "Failed to save input video"
|
209 |
+
|
210 |
+
# Extract frames - using lower fps for longer videos (1 frame per second)
|
211 |
+
ffmpeg.input(input_filename).output(f"{frames_dir}/%04d.png", vf="fps=1").run(quiet=True)
|
212 |
+
|
213 |
+
# Check if frames were extracted
|
214 |
+
frames = sorted([os.path.join(frames_dir, f) for f in os.listdir(frames_dir) if f.endswith('.png')])
|
215 |
+
if not frames:
|
216 |
+
return None, "No frames were extracted from the video"
|
217 |
+
|
218 |
+
# Limit to a maximum number of frames for reasonable processing times
|
219 |
+
if len(frames) > num_frames:
|
220 |
+
# Take evenly distributed frames
|
221 |
+
indices = [int(i * (len(frames) - 1) / (num_frames - 1)) for i in range(num_frames)]
|
222 |
+
frames = [frames[i] for i in indices]
|
223 |
+
|
224 |
+
print(f"Processing {len(frames)} frames")
|
225 |
+
|
226 |
+
# Process frames sequentially if we're using a GPU (to avoid CUDA OOM errors)
|
227 |
+
# Otherwise, use a modest level of parallelism
|
228 |
+
if torch.cuda.is_available():
|
229 |
+
# Sequential processing to avoid CUDA OOM errors
|
230 |
+
processed_frames = []
|
231 |
+
for i, frame in enumerate(frames):
|
232 |
+
success = process_frame(frame, style_prompt)
|
233 |
+
if success:
|
234 |
+
processed_frames.append(frame)
|
235 |
+
print(f"Completed frame {os.path.basename(frame)} ({i+1}/{len(frames)})")
|
236 |
+
else:
|
237 |
+
print(f"Failed to process frame {os.path.basename(frame)}")
|
238 |
+
|
239 |
+
# Free up CUDA cache between frames
|
240 |
+
torch.cuda.empty_cache()
|
241 |
+
else:
|
242 |
+
# Process frames in parallel with limited workers if no GPU
|
243 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
244 |
+
futures = {executor.submit(process_frame, frame, style_prompt): frame for frame in frames}
|
245 |
+
|
246 |
+
# Collect results
|
247 |
+
processed_frames = []
|
248 |
+
for future in concurrent.futures.as_completed(futures):
|
249 |
+
frame = futures[future]
|
250 |
+
if future.result():
|
251 |
+
processed_frames.append(frame)
|
252 |
+
print(f"Completed frame {os.path.basename(frame)} ({len(processed_frames)}/{len(frames)})")
|
253 |
+
|
254 |
+
if not processed_frames:
|
255 |
+
return None, "Failed to process any frames. Please check the logs for more information."
|
256 |
+
|
257 |
+
# Even if not all frames were processed, try to create a video with what we have
|
258 |
+
print(f"Successfully processed {len(processed_frames)}/{len(frames)} frames")
|
259 |
+
|
260 |
+
# Ensure frames are in the correct order (important for video continuity)
|
261 |
+
processed_frames.sort()
|
262 |
+
|
263 |
+
# Reassemble frames into video
|
264 |
+
output_filename = os.path.join(temp_dir, "stylized.mp4")
|
265 |
+
|
266 |
+
# Use a higher bitrate and better codec for higher quality
|
267 |
+
ffmpeg.input(f"{frames_dir}/%04d.png", framerate=1) \
|
268 |
+
.output(output_filename, vcodec='libx264', pix_fmt='yuv420p', crf=18) \
|
269 |
+
.run(quiet=True)
|
270 |
+
|
271 |
+
# Check if the output file exists and has content
|
272 |
+
if not os.path.exists(output_filename) or os.path.getsize(output_filename) == 0:
|
273 |
+
return None, "Failed to create output video"
|
274 |
+
|
275 |
+
# Copy to a persistent location for Gradio to serve
|
276 |
+
os.makedirs("outputs", exist_ok=True)
|
277 |
+
persistent_output = os.path.join("outputs", f"stylized_{uuid.uuid4()}.mp4")
|
278 |
+
shutil.copy(output_filename, persistent_output)
|
279 |
+
|
280 |
+
# Return the relative path (Gradio can handle this)
|
281 |
+
print(f"Output video created at: {persistent_output}")
|
282 |
+
|
283 |
+
# Cleanup temp files
|
284 |
+
shutil.rmtree(temp_dir)
|
285 |
+
|
286 |
+
return persistent_output, f"Video stylized successfully with {len(processed_frames)} frames!"
|
287 |
+
|
288 |
+
except Exception as e:
|
289 |
+
import traceback
|
290 |
+
traceback_str = traceback.format_exc()
|
291 |
+
print(f"Error: {str(e)}\n{traceback_str}")
|
292 |
+
return None, f"Error: {str(e)}"
|
293 |
+
|
294 |
+
# Use Gradio examples feature with local files
|
295 |
+
example_videos = [
|
296 |
+
["sample_video.mp4", "Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"]
|
297 |
+
]
|
298 |
+
|
299 |
+
with gr.Blocks(title="Video-to-Ghibli Style Converter (Open Source)") as iface:
|
300 |
+
gr.Markdown("# Video-to-Ghibli Style Converter (Open Source)")
|
301 |
+
gr.Markdown("Upload a video and convert it to Studio Ghibli animation style using LLaVA and Stable Diffusion.")
|
302 |
+
|
303 |
+
with gr.Row():
|
304 |
+
with gr.Column(scale=2):
|
305 |
+
# Main input column
|
306 |
+
video_input = gr.Video(label="Upload Video (up to 15 seconds)")
|
307 |
+
|
308 |
+
style_prompt = gr.Textbox(
|
309 |
+
label="Style Prompt",
|
310 |
+
value="Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"
|
311 |
+
)
|
312 |
+
|
313 |
+
num_frames_slider = gr.Slider(
|
314 |
+
minimum=5,
|
315 |
+
maximum=15,
|
316 |
+
value=10,
|
317 |
+
step=1,
|
318 |
+
label="Number of frames to process"
|
319 |
+
)
|
320 |
+
|
321 |
+
submit_btn = gr.Button("Stylize Video", variant="primary")
|
322 |
+
|
323 |
+
with gr.Column(scale=2):
|
324 |
+
# Output column
|
325 |
+
video_output = gr.Video(label="Stylized Video")
|
326 |
+
status_output = gr.Textbox(label="Status", value="Ready. Upload a video to start.")
|
327 |
+
|
328 |
+
submit_btn.click(
|
329 |
+
fn=stylize_video,
|
330 |
+
inputs=[video_input, style_prompt, num_frames_slider],
|
331 |
+
outputs=[video_output, status_output]
|
332 |
+
)
|
333 |
+
|
334 |
+
gr.Markdown("""
|
335 |
+
## Instructions
|
336 |
+
1. Upload a video up to 15 seconds long
|
337 |
+
2. Customize the style prompt if desired
|
338 |
+
3. Adjust the number of frames to process (fewer = faster)
|
339 |
+
4. Click "Stylize Video" and wait for processing
|
340 |
+
|
341 |
+
## Example Style Prompts
|
342 |
+
- "Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"
|
343 |
+
- "Studio Ghibli style with magical and dreamy atmosphere"
|
344 |
+
- "Nostalgic Studio Ghibli animation style with watercolor backgrounds and clean linework"
|
345 |
+
- "Ghibli-inspired animation with vibrant colors and fantasy elements"
|
346 |
+
|
347 |
+
Note: Each frame is analyzed by LLaVA-1.5-7B and then transformed by Stable Diffusion (Ghibli-Diffusion model).
|
348 |
+
Videos are processed at 1 frame per second to keep processing time reasonable.
|
349 |
+
|
350 |
+
## Technical Details
|
351 |
+
- Image Analysis: Using LLaVA-1.5-7B for frame understanding and description
|
352 |
+
- Image Generation: Using Stable Diffusion (nitrosocke/Ghibli-Diffusion) for style transfer
|
353 |
+
- All processing happens locally - no API keys needed!
|
354 |
+
""")
|
355 |
+
|
356 |
+
if __name__ == "__main__":
|
357 |
+
iface.launch()
|