RajatMalviya committed on
Commit f673b98 · verified · 1 Parent(s): bea810c

Update app.py

Files changed (1)
  1. app.py +103 -400
app.py CHANGED
@@ -1,416 +1,119 @@
- import gradio as gr
- import ffmpeg
  import os
- import uuid
- import requests
- import tempfile
- import shutil
- import re
  import time
- import concurrent.futures
  import torch
- from pathlib import Path
- from dotenv import load_dotenv
- from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
- from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
  from PIL import Image
- import tempfile
-
-
- # os.makedirs("./hf_cache", exist_ok=True)
- # os.environ["HF_HOME"] = "./hf_cache"
- # os.environ["TRANSFORMERS_CACHE"] = "./hf_cache/transformers"
- # os.environ["HUGGINGFACE_HUB_CACHE"] = "./hf_cache/hub"
-
- # Use system temp directories which should be writable
- TMP_DIR = tempfile.gettempdir()
- MODEL_DIR = os.path.join(TMP_DIR, "hf_models")
-
- # Set environment variables to use these directories
- os.environ["TRANSFORMERS_CACHE"] = os.path.join(TMP_DIR, "transformers_cache")
- os.environ["HF_HOME"] = os.path.join(TMP_DIR, "hf_home")
- os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(TMP_DIR, "hf_hub_cache")
-
- # Helper function to safely create directories
- def safe_makedirs(directory):
      try:
-         os.makedirs(directory, exist_ok=True)
-         return True
-     except (PermissionError, OSError) as e:
-         print(f"Warning: Could not create directory {directory}: {e}")
-         return False
-
- # Create necessary directories
- for directory in [MODEL_DIR, os.environ["TRANSFORMERS_CACHE"],
-                   os.environ["HF_HOME"], os.environ["HUGGINGFACE_HUB_CACHE"]]:
-     safe_makedirs(directory)
-
-
-
-
-
- # Add GPU decorator for Hugging Face Spaces
- try:
-     from spaces import GPU
-     use_gpu = True
-     @GPU
-     def get_gpu():
-         return True
-     # Call the function to trigger GPU allocation
-     get_gpu()
- except ImportError:
-     use_gpu = False
-     print("Running without GPU acceleration")
-
- # Load environment variables from .env file if it exists
- load_dotenv()
-
- # Global variables to hold models (lazy loading)
- llava_model = None
- llava_processor = None
- stable_diffusion_pipeline = None
-
- # Set up the model directory
- MODEL_DIR = "./model"
- os.makedirs(MODEL_DIR, exist_ok=True)
-
- def load_llava_model():
-     """Load LLaVA model for image captioning"""
-     global llava_model, llava_processor
-
-     if llava_model is None or llava_processor is None:
-         print("Loading LLaVA model for image analysis...")
-         model_id = "llava-hf/llava-1.5-7b-hf"

          try:
-             # Load processor and model with system temp directory
-             llava_processor = AutoProcessor.from_pretrained(
-                 model_id,
-                 local_files_only=False
-             )
-             llava_model = AutoModelForCausalLM.from_pretrained(
-                 model_id,
-                 torch_dtype=torch.float16,
-                 device_map="auto",
-                 local_files_only=False
-             )
          except Exception as e:
-             print(f"Error loading LLaVA model: {e}")
-             raise
-
-     return llava_model, llava_processor
-
-
- # In the stylize_video function, replace:
- os.makedirs("outputs", exist_ok=True)
- persistent_output = os.path.join("outputs", f"stylized_{uuid.uuid4()}.mp4")
-
- # With:
- outputs_dir = os.path.join(TMP_DIR, "outputs")
- safe_makedirs(outputs_dir)
- persistent_output = os.path.join(outputs_dir, f"stylized_{uuid.uuid4()}.mp4")
-
-
- def load_stable_diffusion_model():
-     """Load Stable Diffusion model for Ghibli-style image generation"""
-     global stable_diffusion_pipeline
-
-     if stable_diffusion_pipeline is None:
-         print("Loading Stable Diffusion model for image generation...")
-         model_id = "nitrosocke/Ghibli-Diffusion"
-
-         # Load the pipeline with precision to balance memory usage and quality
-         stable_diffusion_pipeline = StableDiffusionPipeline.from_pretrained(
-             model_id,
-             torch_dtype=torch.float16,
-             safety_checker=None,  # Disable safety checker for performance
-             cache_dir=os.path.join(MODEL_DIR, "stable_diffusion")
-         )
-
-
-         # Move to GPU if available
-         if torch.cuda.is_available():
-             stable_diffusion_pipeline = stable_diffusion_pipeline.to("cuda")
-
-         # Use the DPM-Solver++ scheduler for better quality at lower steps
-         stable_diffusion_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-             stable_diffusion_pipeline.scheduler.config,
-             algorithm_type="dpmsolver++",
-             use_karras_sigmas=True
-         )
-
-     return stable_diffusion_pipeline
-
- def analyze_image_with_llava(image_path):
-     """Process a single frame with LLaVA to generate a description"""
-     try:
-         # Load the model if not already loaded
-         model, processor = load_llava_model()
-
-         # Load the image
-         image = Image.open(image_path)
-
-         # Prompt for Ghibli-specific description
-         prompt = "Describe this image in detail, focusing on elements that would be important to recreate it in Studio Ghibli animation style."
-
-         # Process the image and generate text
-         inputs = processor(prompt, image, return_tensors="pt").to(model.device)
-
-         # Generate with appropriate parameters
-         with torch.no_grad():
-             output = model.generate(
-                 **inputs,
-                 max_new_tokens=300,
-                 do_sample=True,
-                 temperature=0.7,
-                 top_p=0.9,
-             )
-
-         # Decode the output
-         generated_text = processor.decode(output[0], skip_special_tokens=True)
-
-         # Extract just the response part (remove the prompt)
-         response = generated_text.split(prompt)[-1].strip()
-         print(f"LLaVA analysis for frame {os.path.basename(image_path)}: {response[:150]}...")
-
-         return response
-
-     except Exception as e:
-         import traceback
-         print(f"Error analyzing image {os.path.basename(image_path)}: {str(e)}")
-         print(traceback.format_exc())
-         return f"Error analyzing image: {str(e)}"
-
- def generate_ghibli_image(image_description, style_prompt, output_path):
-     """Generate a Ghibli-style image based on the description using Stable Diffusion"""
-     try:
-         # Load the model if not already loaded
-         pipeline = load_stable_diffusion_model()
-
-         # Combine the image description with the style prompt
-         full_prompt = f"{image_description}. {style_prompt}. Hand-drawn animation style, soft colors, attention to detail, Miyazaki aesthetic."
-
-         # Ensure prompt isn't too long
-         if len(full_prompt) > 500:
-             full_prompt = full_prompt[:497] + "..."
-
-         # Generate the image
-         with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
-             image = pipeline(
-                 prompt=full_prompt,
-                 negative_prompt="3d, cgi, low quality, blurry, distorted, deformed",
-                 num_inference_steps=30,
-                 guidance_scale=7.5,
-                 width=768,
-                 height=768,
              ).images[0]

-         # Save the generated image
-         image.save(output_path)
-         print(f"Successfully saved stylized frame: {os.path.basename(output_path)}")
-         return True
-
-     except Exception as e:
-         import traceback
-         print(f"Error generating image: {str(e)}")
-         print(traceback.format_exc())
-         return False
-
- def process_frame(frame_path, style_prompt):
-     """Process a single frame with LLaVA analysis and Stable Diffusion generation"""
-     try:
-         # First use LLaVA to analyze the image
-         image_description = analyze_image_with_llava(frame_path)
-
-         if image_description.startswith("Error"):
-             return False

-         # Now use Stable Diffusion to generate a stylized version
-         result = generate_ghibli_image(image_description, style_prompt, frame_path)

-         return result
-
-     except Exception as e:
-         import traceback
-         print(f"Error processing frame {os.path.basename(frame_path)}: {str(e)}")
-         print(traceback.format_exc())
-         return False
-
- def stylize_video(video_path, style_prompt, num_frames=15):
-     try:
-         # Create temp directories
-         temp_dir = tempfile.mkdtemp()
-         input_filename = os.path.join(temp_dir, "input.mp4")
-         frames_dir = os.path.join(temp_dir, "frames")
-         os.makedirs(frames_dir, exist_ok=True)
-
-         # Save the input video to a temporary file
-         if isinstance(video_path, str):
-             if video_path.startswith('http'):
-                 # It's a URL, download it
-                 response = requests.get(video_path, stream=True)
-                 with open(input_filename, 'wb') as f:
-                     for chunk in response.iter_content(chunk_size=8192):
-                         f.write(chunk)
-             elif os.path.exists(video_path):
-                 # It's a file path, copy it
-                 shutil.copy(video_path, input_filename)
-             else:
-                 return None, f"Video file not found: {video_path}"
-         else:
-             # Assume it's binary data
-             with open(input_filename, "wb") as f:
-                 f.write(video_path)
-
-         # Make sure the video file exists
-         if not os.path.exists(input_filename):
-             return None, "Failed to save input video"
-
-         # Extract frames - using lower fps for longer videos (1 frame per second)
-         ffmpeg.input(input_filename).output(f"{frames_dir}/%04d.png", vf="fps=1").run(quiet=True)
-
-         # Check if frames were extracted
-         frames = sorted([os.path.join(frames_dir, f) for f in os.listdir(frames_dir) if f.endswith('.png')])
-         if not frames:
-             return None, "No frames were extracted from the video"
-
-         # Limit to a maximum number of frames for reasonable processing times
-         if len(frames) > num_frames:
-             # Take evenly distributed frames
-             indices = [int(i * (len(frames) - 1) / (num_frames - 1)) for i in range(num_frames)]
-             frames = [frames[i] for i in indices]
-
-         print(f"Processing {len(frames)} frames")
-
-         # Process frames sequentially if we're using a GPU (to avoid CUDA OOM errors)
-         # Otherwise, use a modest level of parallelism
-         if torch.cuda.is_available():
-             # Sequential processing to avoid CUDA OOM errors
-             processed_frames = []
-             for i, frame in enumerate(frames):
-                 success = process_frame(frame, style_prompt)
-                 if success:
-                     processed_frames.append(frame)
-                     print(f"Completed frame {os.path.basename(frame)} ({i+1}/{len(frames)})")
-                 else:
-                     print(f"Failed to process frame {os.path.basename(frame)}")
-
-                 # Free up CUDA cache between frames
-                 torch.cuda.empty_cache()
-         else:
-             # Process frames in parallel with limited workers if no GPU
-             with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
-                 futures = {executor.submit(process_frame, frame, style_prompt): frame for frame in frames}
-
-                 # Collect results
-                 processed_frames = []
-                 for future in concurrent.futures.as_completed(futures):
-                     frame = futures[future]
-                     if future.result():
-                         processed_frames.append(frame)
-                         print(f"Completed frame {os.path.basename(frame)} ({len(processed_frames)}/{len(frames)})")
-
-         if not processed_frames:
-             return None, "Failed to process any frames. Please check the logs for more information."
-
-         # Even if not all frames were processed, try to create a video with what we have
-         print(f"Successfully processed {len(processed_frames)}/{len(frames)} frames")
-
-         # Ensure frames are in the correct order (important for video continuity)
-         processed_frames.sort()
-
-         # Reassemble frames into video
-         output_filename = os.path.join(temp_dir, "stylized.mp4")
-
-         # Use a higher bitrate and better codec for higher quality
-         ffmpeg.input(f"{frames_dir}/%04d.png", framerate=1) \
-             .output(output_filename, vcodec='libx264', pix_fmt='yuv420p', crf=18) \
-             .run(quiet=True)
-
-         # Check if the output file exists and has content
-         if not os.path.exists(output_filename) or os.path.getsize(output_filename) == 0:
-             return None, "Failed to create output video"
-
-         # Copy to a persistent location for Gradio to serve
-         os.makedirs("outputs", exist_ok=True)
-         persistent_output = os.path.join("outputs", f"stylized_{uuid.uuid4()}.mp4")
-         shutil.copy(output_filename, persistent_output)
-
-         # Return the relative path (Gradio can handle this)
-         print(f"Output video created at: {persistent_output}")
-
-         # Cleanup temp files
-         shutil.rmtree(temp_dir)
-
-         return persistent_output, f"Video stylized successfully with {len(processed_frames)} frames!"

      except Exception as e:
-         import traceback
-         traceback_str = traceback.format_exc()
-         print(f"Error: {str(e)}\n{traceback_str}")
-         return None, f"Error: {str(e)}"
-
- # Use Gradio examples feature with local files
- example_videos = [
-     ["sample_video.mp4", "Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"]
- ]
-
- with gr.Blocks(title="Video-to-Ghibli Style Converter (Open Source)") as iface:
-     gr.Markdown("# Video-to-Ghibli Style Converter (Open Source)")
-     gr.Markdown("Upload a video and convert it to Studio Ghibli animation style using LLaVA and Stable Diffusion.")
-
-     with gr.Row():
-         with gr.Column(scale=2):
-             # Main input column
-             video_input = gr.Video(label="Upload Video (up to 15 seconds)")
-
-             style_prompt = gr.Textbox(
-                 label="Style Prompt",
-                 value="Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"
-             )
-
-             num_frames_slider = gr.Slider(
-                 minimum=5,
-                 maximum=15,
-                 value=10,
-                 step=1,
-                 label="Number of frames to process"
-             )
-
-             submit_btn = gr.Button("Stylize Video", variant="primary")
-
-         with gr.Column(scale=2):
-             # Output column
-             video_output = gr.Video(label="Stylized Video")
-             status_output = gr.Textbox(label="Status", value="Ready. Upload a video to start.")
-
-     submit_btn.click(
-         fn=stylize_video,
-         inputs=[video_input, style_prompt, num_frames_slider],
-         outputs=[video_output, status_output]
-     )
-
-     gr.Markdown("""
-     ## Instructions
-     1. Upload a video up to 15 seconds long
-     2. Customize the style prompt if desired
-     3. Adjust the number of frames to process (fewer = faster)
-     4. Click "Stylize Video" and wait for processing
-
-     ## Example Style Prompts
-     - "Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"
-     - "Studio Ghibli style with magical and dreamy atmosphere"
-     - "Nostalgic Studio Ghibli animation style with watercolor backgrounds and clean linework"
-     - "Ghibli-inspired animation with vibrant colors and fantasy elements"
-
-     Note: Each frame is analyzed by LLaVA-1.5-7B and then transformed by Stable Diffusion (Ghibli-Diffusion model).
-     Videos are processed at 1 frame per second to keep processing time reasonable.
-
-     ## Technical Details
-     - Image Analysis: Using LLaVA-1.5-7B for frame understanding and description
-     - Image Generation: Using Stable Diffusion (nitrosocke/Ghibli-Diffusion) for style transfer
-     - All processing happens locally - no API keys needed!
-     """)

  if __name__ == "__main__":
-     iface.launch()
  import os
+ import io
+ import json
+ import base64
  import time
  import torch
  from PIL import Image
+ from typing import Optional
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+ from fastapi.responses import Response
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from safetensors.torch import save_file
+ from src.pipeline import FluxPipeline
+ from src.transformer_flux import FluxTransformer2DModel
+ from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
+
+ # Define paths
+ base_path = "black-forest-labs/FLUX.1-dev"
+ lora_base_path = "./models"
+
+ # Initialize the model
+ print("Loading model...")
+ pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
+ transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
+ pipe.transformer = transformer
+ pipe.to("cuda")
+ print("Model loaded successfully!")
+
+ # Function to clear cache
+ def clear_cache(transformer):
+     for name, attn_processor in transformer.attn_processors.items():
+         attn_processor.bank_kv.clear()
+
+ # Create FastAPI app
+ app = FastAPI(title="Ghibli Image Generator API",
+               description="Convert images to Ghibli Studio style using EasyControl")
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Health check endpoint
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy", "model": "loaded"}
+
+ # Main image conversion endpoint
+ @app.post("/generate-ghibli")
+ async def generate_ghibli(
+     file: UploadFile = File(...),
+     prompt: str = Form("Ghibli Studio style, Charming hand-drawn anime-style illustration"),
+     height: int = Form(768),
+     width: int = Form(768),
+     seed: int = Form(42)
+ ):
      try:
+         # Validate input image
+         if not file.content_type.startswith("image/"):
+             raise HTTPException(status_code=400, detail="File must be an image")

+         # Read and validate image
+         image_data = await file.read()
          try:
+             spatial_img = Image.open(io.BytesIO(image_data))
          except Exception as e:
+             raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}")
+
+         # Validate dimensions
+         if height < 256 or height > 1024 or width < 256 or width > 1024:
+             raise HTTPException(status_code=400, detail="Dimensions must be between 256 and 1024")
+
+         # Configure LoRA
+         lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
+         set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
+
+         # Generate image
+         with torch.cuda.amp.autocast():
+             output = pipe(
+                 prompt,
+                 height=height,
+                 width=width,
+                 guidance_scale=3.5,
+                 num_inference_steps=25,
+                 max_sequence_length=512,
+                 generator=torch.Generator("cpu").manual_seed(seed),
+                 subject_images=[],
+                 spatial_images=[spatial_img],
+                 cond_size=512,
              ).images[0]

+         # Clear cache
+         clear_cache(pipe.transformer)

+         # Convert output to bytes
+         img_byte_arr = io.BytesIO()
+         output.save(img_byte_arr, format='PNG')
+         img_byte_arr.seek(0)

+         # Return the image directly
+         return Response(
+             content=img_byte_arr.getvalue(),
+             media_type="image/png"
+         )

+     except HTTPException as e:
+         raise e
      except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")

+ # Run the API with uvicorn
  if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("app:app", host="0.0.0.0", port=7860)
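The new app.py exposes the conversion as an HTTP endpoint rather than a Gradio UI. Below is a minimal client sketch, assuming the server is running with the uvicorn defaults above (port 7860 on the local machine); the requests dependency and the photo.jpg / ghibli.png filenames are illustrative placeholders and are not part of the commit.

# Hypothetical client example: POST an image to /generate-ghibli and save the PNG reply.
import requests

with open("photo.jpg", "rb") as f:  # placeholder input image
    resp = requests.post(
        "http://localhost:7860/generate-ghibli",
        files={"file": ("photo.jpg", f, "image/jpeg")},
        data={
            "prompt": "Ghibli Studio style, Charming hand-drawn anime-style illustration",
            "height": 768,
            "width": 768,
            "seed": 42,
        },
        timeout=600,  # generation can take a while on first call
    )
resp.raise_for_status()

with open("ghibli.png", "wb") as out:  # placeholder output path
    out.write(resp.content)  # the endpoint returns the PNG bytes directly

Because the endpoint returns the generated image as the response body with media_type "image/png", the client can write resp.content straight to disk; no JSON decoding is needed.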