Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,416 +1,119 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import ffmpeg
|
3 |
import os
|
4 |
-
import
|
5 |
-
import
|
6 |
-
import
|
7 |
-
import shutil
|
8 |
-
import re
|
9 |
import time
|
10 |
-
import concurrent.futures
|
11 |
import torch
|
12 |
-
from pathlib import Path
|
13 |
-
from dotenv import load_dotenv
|
14 |
-
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
|
15 |
-
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
16 |
from PIL import Image
|
17 |
-
import
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
try:
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
print(f"Warning: Could not create directory {directory}: {e}")
|
41 |
-
return False
|
42 |
-
|
43 |
-
# Create necessary directories
|
44 |
-
for directory in [MODEL_DIR, os.environ["TRANSFORMERS_CACHE"],
|
45 |
-
os.environ["HF_HOME"], os.environ["HUGGINGFACE_HUB_CACHE"]]:
|
46 |
-
safe_makedirs(directory)
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
# Add GPU decorator for Hugging Face Spaces
|
53 |
-
try:
|
54 |
-
from spaces import GPU
|
55 |
-
use_gpu = True
|
56 |
-
@GPU
|
57 |
-
def get_gpu():
|
58 |
-
return True
|
59 |
-
# Call the function to trigger GPU allocation
|
60 |
-
get_gpu()
|
61 |
-
except ImportError:
|
62 |
-
use_gpu = False
|
63 |
-
print("Running without GPU acceleration")
|
64 |
-
|
65 |
-
# Load environment variables from .env file if it exists
|
66 |
-
load_dotenv()
|
67 |
-
|
68 |
-
# Global variables to hold models (lazy loading)
|
69 |
-
llava_model = None
|
70 |
-
llava_processor = None
|
71 |
-
stable_diffusion_pipeline = None
|
72 |
-
|
73 |
-
# Set up the model directory
|
74 |
-
MODEL_DIR = "./model"
|
75 |
-
os.makedirs(MODEL_DIR, exist_ok=True)
|
76 |
-
|
77 |
-
def load_llava_model():
|
78 |
-
"""Load LLaVA model for image captioning"""
|
79 |
-
global llava_model, llava_processor
|
80 |
-
|
81 |
-
if llava_model is None or llava_processor is None:
|
82 |
-
print("Loading LLaVA model for image analysis...")
|
83 |
-
model_id = "llava-hf/llava-1.5-7b-hf"
|
84 |
|
|
|
|
|
85 |
try:
|
86 |
-
|
87 |
-
llava_processor = AutoProcessor.from_pretrained(
|
88 |
-
model_id,
|
89 |
-
local_files_only=False
|
90 |
-
)
|
91 |
-
llava_model = AutoModelForCausalLM.from_pretrained(
|
92 |
-
model_id,
|
93 |
-
torch_dtype=torch.float16,
|
94 |
-
device_map="auto",
|
95 |
-
local_files_only=False
|
96 |
-
)
|
97 |
except Exception as e:
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
#
|
105 |
-
os.
|
106 |
-
|
107 |
-
|
108 |
-
#
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
# Load the pipeline with precision to balance memory usage and quality
|
123 |
-
stable_diffusion_pipeline = StableDiffusionPipeline.from_pretrained(
|
124 |
-
model_id,
|
125 |
-
torch_dtype=torch.float16,
|
126 |
-
safety_checker=None, # Disable safety checker for performance
|
127 |
-
cache_dir=os.path.join(MODEL_DIR, "stable_diffusion")
|
128 |
-
)
|
129 |
-
|
130 |
-
|
131 |
-
# Move to GPU if available
|
132 |
-
if torch.cuda.is_available():
|
133 |
-
stable_diffusion_pipeline = stable_diffusion_pipeline.to("cuda")
|
134 |
-
|
135 |
-
# Use the DPM-Solver++ scheduler for better quality at lower steps
|
136 |
-
stable_diffusion_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
137 |
-
stable_diffusion_pipeline.scheduler.config,
|
138 |
-
algorithm_type="dpmsolver++",
|
139 |
-
use_karras_sigmas=True
|
140 |
-
)
|
141 |
-
|
142 |
-
return stable_diffusion_pipeline
|
143 |
-
|
144 |
-
def analyze_image_with_llava(image_path):
|
145 |
-
"""Process a single frame with LLaVA to generate a description"""
|
146 |
-
try:
|
147 |
-
# Load the model if not already loaded
|
148 |
-
model, processor = load_llava_model()
|
149 |
-
|
150 |
-
# Load the image
|
151 |
-
image = Image.open(image_path)
|
152 |
-
|
153 |
-
# Prompt for Ghibli-specific description
|
154 |
-
prompt = "Describe this image in detail, focusing on elements that would be important to recreate it in Studio Ghibli animation style."
|
155 |
-
|
156 |
-
# Process the image and generate text
|
157 |
-
inputs = processor(prompt, image, return_tensors="pt").to(model.device)
|
158 |
-
|
159 |
-
# Generate with appropriate parameters
|
160 |
-
with torch.no_grad():
|
161 |
-
output = model.generate(
|
162 |
-
**inputs,
|
163 |
-
max_new_tokens=300,
|
164 |
-
do_sample=True,
|
165 |
-
temperature=0.7,
|
166 |
-
top_p=0.9,
|
167 |
-
)
|
168 |
-
|
169 |
-
# Decode the output
|
170 |
-
generated_text = processor.decode(output[0], skip_special_tokens=True)
|
171 |
-
|
172 |
-
# Extract just the response part (remove the prompt)
|
173 |
-
response = generated_text.split(prompt)[-1].strip()
|
174 |
-
print(f"LLaVA analysis for frame {os.path.basename(image_path)}: {response[:150]}...")
|
175 |
-
|
176 |
-
return response
|
177 |
-
|
178 |
-
except Exception as e:
|
179 |
-
import traceback
|
180 |
-
print(f"Error analyzing image {os.path.basename(image_path)}: {str(e)}")
|
181 |
-
print(traceback.format_exc())
|
182 |
-
return f"Error analyzing image: {str(e)}"
|
183 |
-
|
184 |
-
def generate_ghibli_image(image_description, style_prompt, output_path):
|
185 |
-
"""Generate a Ghibli-style image based on the description using Stable Diffusion"""
|
186 |
-
try:
|
187 |
-
# Load the model if not already loaded
|
188 |
-
pipeline = load_stable_diffusion_model()
|
189 |
-
|
190 |
-
# Combine the image description with the style prompt
|
191 |
-
full_prompt = f"{image_description}. {style_prompt}. Hand-drawn animation style, soft colors, attention to detail, Miyazaki aesthetic."
|
192 |
-
|
193 |
-
# Ensure prompt isn't too long
|
194 |
-
if len(full_prompt) > 500:
|
195 |
-
full_prompt = full_prompt[:497] + "..."
|
196 |
-
|
197 |
-
# Generate the image
|
198 |
-
with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
|
199 |
-
image = pipeline(
|
200 |
-
prompt=full_prompt,
|
201 |
-
negative_prompt="3d, cgi, low quality, blurry, distorted, deformed",
|
202 |
-
num_inference_steps=30,
|
203 |
-
guidance_scale=7.5,
|
204 |
-
width=768,
|
205 |
-
height=768,
|
206 |
).images[0]
|
207 |
|
208 |
-
#
|
209 |
-
|
210 |
-
print(f"Successfully saved stylized frame: {os.path.basename(output_path)}")
|
211 |
-
return True
|
212 |
-
|
213 |
-
except Exception as e:
|
214 |
-
import traceback
|
215 |
-
print(f"Error generating image: {str(e)}")
|
216 |
-
print(traceback.format_exc())
|
217 |
-
return False
|
218 |
-
|
219 |
-
def process_frame(frame_path, style_prompt):
|
220 |
-
"""Process a single frame with LLaVA analysis and Stable Diffusion generation"""
|
221 |
-
try:
|
222 |
-
# First use LLaVA to analyze the image
|
223 |
-
image_description = analyze_image_with_llava(frame_path)
|
224 |
-
|
225 |
-
if image_description.startswith("Error"):
|
226 |
-
return False
|
227 |
|
228 |
-
#
|
229 |
-
|
|
|
|
|
230 |
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
print(traceback.format_exc())
|
237 |
-
return False
|
238 |
-
|
239 |
-
def stylize_video(video_path, style_prompt, num_frames=15):
|
240 |
-
try:
|
241 |
-
# Create temp directories
|
242 |
-
temp_dir = tempfile.mkdtemp()
|
243 |
-
input_filename = os.path.join(temp_dir, "input.mp4")
|
244 |
-
frames_dir = os.path.join(temp_dir, "frames")
|
245 |
-
os.makedirs(frames_dir, exist_ok=True)
|
246 |
-
|
247 |
-
# Save the input video to a temporary file
|
248 |
-
if isinstance(video_path, str):
|
249 |
-
if video_path.startswith('http'):
|
250 |
-
# It's a URL, download it
|
251 |
-
response = requests.get(video_path, stream=True)
|
252 |
-
with open(input_filename, 'wb') as f:
|
253 |
-
for chunk in response.iter_content(chunk_size=8192):
|
254 |
-
f.write(chunk)
|
255 |
-
elif os.path.exists(video_path):
|
256 |
-
# It's a file path, copy it
|
257 |
-
shutil.copy(video_path, input_filename)
|
258 |
-
else:
|
259 |
-
return None, f"Video file not found: {video_path}"
|
260 |
-
else:
|
261 |
-
# Assume it's binary data
|
262 |
-
with open(input_filename, "wb") as f:
|
263 |
-
f.write(video_path)
|
264 |
-
|
265 |
-
# Make sure the video file exists
|
266 |
-
if not os.path.exists(input_filename):
|
267 |
-
return None, "Failed to save input video"
|
268 |
-
|
269 |
-
# Extract frames - using lower fps for longer videos (1 frame per second)
|
270 |
-
ffmpeg.input(input_filename).output(f"{frames_dir}/%04d.png", vf="fps=1").run(quiet=True)
|
271 |
-
|
272 |
-
# Check if frames were extracted
|
273 |
-
frames = sorted([os.path.join(frames_dir, f) for f in os.listdir(frames_dir) if f.endswith('.png')])
|
274 |
-
if not frames:
|
275 |
-
return None, "No frames were extracted from the video"
|
276 |
-
|
277 |
-
# Limit to a maximum number of frames for reasonable processing times
|
278 |
-
if len(frames) > num_frames:
|
279 |
-
# Take evenly distributed frames
|
280 |
-
indices = [int(i * (len(frames) - 1) / (num_frames - 1)) for i in range(num_frames)]
|
281 |
-
frames = [frames[i] for i in indices]
|
282 |
-
|
283 |
-
print(f"Processing {len(frames)} frames")
|
284 |
-
|
285 |
-
# Process frames sequentially if we're using a GPU (to avoid CUDA OOM errors)
|
286 |
-
# Otherwise, use a modest level of parallelism
|
287 |
-
if torch.cuda.is_available():
|
288 |
-
# Sequential processing to avoid CUDA OOM errors
|
289 |
-
processed_frames = []
|
290 |
-
for i, frame in enumerate(frames):
|
291 |
-
success = process_frame(frame, style_prompt)
|
292 |
-
if success:
|
293 |
-
processed_frames.append(frame)
|
294 |
-
print(f"Completed frame {os.path.basename(frame)} ({i+1}/{len(frames)})")
|
295 |
-
else:
|
296 |
-
print(f"Failed to process frame {os.path.basename(frame)}")
|
297 |
-
|
298 |
-
# Free up CUDA cache between frames
|
299 |
-
torch.cuda.empty_cache()
|
300 |
-
else:
|
301 |
-
# Process frames in parallel with limited workers if no GPU
|
302 |
-
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
303 |
-
futures = {executor.submit(process_frame, frame, style_prompt): frame for frame in frames}
|
304 |
-
|
305 |
-
# Collect results
|
306 |
-
processed_frames = []
|
307 |
-
for future in concurrent.futures.as_completed(futures):
|
308 |
-
frame = futures[future]
|
309 |
-
if future.result():
|
310 |
-
processed_frames.append(frame)
|
311 |
-
print(f"Completed frame {os.path.basename(frame)} ({len(processed_frames)}/{len(frames)})")
|
312 |
-
|
313 |
-
if not processed_frames:
|
314 |
-
return None, "Failed to process any frames. Please check the logs for more information."
|
315 |
-
|
316 |
-
# Even if not all frames were processed, try to create a video with what we have
|
317 |
-
print(f"Successfully processed {len(processed_frames)}/{len(frames)} frames")
|
318 |
-
|
319 |
-
# Ensure frames are in the correct order (important for video continuity)
|
320 |
-
processed_frames.sort()
|
321 |
-
|
322 |
-
# Reassemble frames into video
|
323 |
-
output_filename = os.path.join(temp_dir, "stylized.mp4")
|
324 |
-
|
325 |
-
# Use a higher bitrate and better codec for higher quality
|
326 |
-
ffmpeg.input(f"{frames_dir}/%04d.png", framerate=1) \
|
327 |
-
.output(output_filename, vcodec='libx264', pix_fmt='yuv420p', crf=18) \
|
328 |
-
.run(quiet=True)
|
329 |
-
|
330 |
-
# Check if the output file exists and has content
|
331 |
-
if not os.path.exists(output_filename) or os.path.getsize(output_filename) == 0:
|
332 |
-
return None, "Failed to create output video"
|
333 |
-
|
334 |
-
# Copy to a persistent location for Gradio to serve
|
335 |
-
os.makedirs("outputs", exist_ok=True)
|
336 |
-
persistent_output = os.path.join("outputs", f"stylized_{uuid.uuid4()}.mp4")
|
337 |
-
shutil.copy(output_filename, persistent_output)
|
338 |
-
|
339 |
-
# Return the relative path (Gradio can handle this)
|
340 |
-
print(f"Output video created at: {persistent_output}")
|
341 |
-
|
342 |
-
# Cleanup temp files
|
343 |
-
shutil.rmtree(temp_dir)
|
344 |
-
|
345 |
-
return persistent_output, f"Video stylized successfully with {len(processed_frames)} frames!"
|
346 |
|
|
|
|
|
347 |
except Exception as e:
|
348 |
-
|
349 |
-
traceback_str = traceback.format_exc()
|
350 |
-
print(f"Error: {str(e)}\n{traceback_str}")
|
351 |
-
return None, f"Error: {str(e)}"
|
352 |
-
|
353 |
-
# Use Gradio examples feature with local files
|
354 |
-
example_videos = [
|
355 |
-
["sample_video.mp4", "Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"]
|
356 |
-
]
|
357 |
-
|
358 |
-
with gr.Blocks(title="Video-to-Ghibli Style Converter (Open Source)") as iface:
|
359 |
-
gr.Markdown("# Video-to-Ghibli Style Converter (Open Source)")
|
360 |
-
gr.Markdown("Upload a video and convert it to Studio Ghibli animation style using LLaVA and Stable Diffusion.")
|
361 |
-
|
362 |
-
with gr.Row():
|
363 |
-
with gr.Column(scale=2):
|
364 |
-
# Main input column
|
365 |
-
video_input = gr.Video(label="Upload Video (up to 15 seconds)")
|
366 |
-
|
367 |
-
style_prompt = gr.Textbox(
|
368 |
-
label="Style Prompt",
|
369 |
-
value="Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"
|
370 |
-
)
|
371 |
-
|
372 |
-
num_frames_slider = gr.Slider(
|
373 |
-
minimum=5,
|
374 |
-
maximum=15,
|
375 |
-
value=10,
|
376 |
-
step=1,
|
377 |
-
label="Number of frames to process"
|
378 |
-
)
|
379 |
-
|
380 |
-
submit_btn = gr.Button("Stylize Video", variant="primary")
|
381 |
-
|
382 |
-
with gr.Column(scale=2):
|
383 |
-
# Output column
|
384 |
-
video_output = gr.Video(label="Stylized Video")
|
385 |
-
status_output = gr.Textbox(label="Status", value="Ready. Upload a video to start.")
|
386 |
-
|
387 |
-
submit_btn.click(
|
388 |
-
fn=stylize_video,
|
389 |
-
inputs=[video_input, style_prompt, num_frames_slider],
|
390 |
-
outputs=[video_output, status_output]
|
391 |
-
)
|
392 |
-
|
393 |
-
gr.Markdown("""
|
394 |
-
## Instructions
|
395 |
-
1. Upload a video up to 15 seconds long
|
396 |
-
2. Customize the style prompt if desired
|
397 |
-
3. Adjust the number of frames to process (fewer = faster)
|
398 |
-
4. Click "Stylize Video" and wait for processing
|
399 |
-
|
400 |
-
## Example Style Prompts
|
401 |
-
- "Studio Ghibli animation with Hayao Miyazaki's distinctive hand-drawn art style"
|
402 |
-
- "Studio Ghibli style with magical and dreamy atmosphere"
|
403 |
-
- "Nostalgic Studio Ghibli animation style with watercolor backgrounds and clean linework"
|
404 |
-
- "Ghibli-inspired animation with vibrant colors and fantasy elements"
|
405 |
-
|
406 |
-
Note: Each frame is analyzed by LLaVA-1.5-7B and then transformed by Stable Diffusion (Ghibli-Diffusion model).
|
407 |
-
Videos are processed at 1 frame per second to keep processing time reasonable.
|
408 |
-
|
409 |
-
## Technical Details
|
410 |
-
- Image Analysis: Using LLaVA-1.5-7B for frame understanding and description
|
411 |
-
- Image Generation: Using Stable Diffusion (nitrosocke/Ghibli-Diffusion) for style transfer
|
412 |
-
- All processing happens locally - no API keys needed!
|
413 |
-
""")
|
414 |
|
|
|
415 |
if __name__ == "__main__":
|
416 |
-
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import base64
|
|
|
|
|
5 |
import time
|
|
|
6 |
import torch
|
|
|
|
|
|
|
|
|
7 |
from PIL import Image
|
8 |
+
from typing import Optional
|
9 |
+
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
10 |
+
from fastapi.responses import Response
|
11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
12 |
+
|
13 |
+
from safetensors.torch import save_file
|
14 |
+
from src.pipeline import FluxPipeline
|
15 |
+
from src.transformer_flux import FluxTransformer2DModel
|
16 |
+
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
|
17 |
+
|
18 |
+
# Base checkpoint identifier and the local directory that holds LoRA weights.
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"

# Build the pipeline once at import time: load the pipeline, load the
# project's transformer implementation separately, swap it in, and move
# everything to the GPU.
print("Loading model...")
pipe = FluxPipeline.from_pretrained(
    base_path,
    torch_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    base_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe.transformer = transformer
pipe.to("cuda")
print("Model loaded successfully!")
|
29 |
+
|
30 |
+
# Function to clear cache
|
31 |
+
def clear_cache(transformer):
    """Empty the ``bank_kv`` cache of every attention processor.

    Args:
        transformer: Object exposing an ``attn_processors`` mapping whose
            values each carry a ``bank_kv`` container with a ``clear()``
            method.
    """
    # Only the processor objects are needed; the mapping keys are unused.
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()
|
34 |
+
|
35 |
+
# Create FastAPI app
app = FastAPI(
    title="Ghibli Image Generator API",
    description="Convert images to Ghibli Studio style using EasyControl",
)

# Allow browser clients from any origin to call the API. Credentials are
# disabled because wildcard origins/methods/headers must not be combined with
# credentialed requests; the endpoints visible in this file do not read
# cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
47 |
+
|
48 |
+
# Health check endpoint
|
49 |
+
@app.get("/health")
async def health_check():
    # Fixed liveness payload; nothing about the model is inspected here.
    payload = {"status": "healthy", "model": "loaded"}
    return payload
|
52 |
+
|
53 |
+
# Main image conversion endpoint
|
54 |
+
@app.post("/generate-ghibli")
async def generate_ghibli(
    file: UploadFile = File(...),
    prompt: str = Form("Ghibli Studio style, Charming hand-drawn anime-style illustration"),
    height: int = Form(768),
    width: int = Form(768),
    seed: int = Form(42),
):
    """Convert an uploaded image to Ghibli style and return it as a PNG.

    Args:
        file: Uploaded image used as the spatial conditioning image.
        prompt: Text prompt passed to the pipeline.
        height: Output height in pixels; must be within 256..1024.
        width: Output width in pixels; must be within 256..1024.
        seed: Seed for the CPU random generator used by the pipeline.

    Returns:
        A ``Response`` whose body is the generated image encoded as PNG.

    Raises:
        HTTPException: 400 when the upload is not a decodable image or the
            dimensions are out of range; 500 for any other failure.
    """
    try:
        # Validate input image. content_type may be None when the client
        # omits it, so fall back to an empty string instead of crashing.
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and validate image
        image_data = await file.read()
        try:
            spatial_img = Image.open(io.BytesIO(image_data))
            # Image.open is lazy; force decoding now so corrupt or truncated
            # data is reported as a client error rather than failing later.
            spatial_img.load()
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {str(e)}") from e

        # Validate dimensions
        if not (256 <= height <= 1024 and 256 <= width <= 1024):
            raise HTTPException(status_code=400, detail="Dimensions must be between 256 and 1024")

        # Configure LoRA
        lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
        set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)

        # Generate image. The attention cache is cleared in ``finally`` so a
        # failed generation cannot leave stale state for the next request.
        try:
            with torch.cuda.amp.autocast():
                output = pipe(
                    prompt,
                    height=height,
                    width=width,
                    guidance_scale=3.5,
                    num_inference_steps=25,
                    max_sequence_length=512,
                    generator=torch.Generator("cpu").manual_seed(seed),
                    subject_images=[],
                    spatial_images=[spatial_img],
                    cond_size=512,
                ).images[0]
        finally:
            clear_cache(pipe.transformer)

        # Convert output to bytes
        img_byte_arr = io.BytesIO()
        output.save(img_byte_arr, format="PNG")

        # Return the image directly
        return Response(
            content=img_byte_arr.getvalue(),
            media_type="image/png",
        )

    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
+
# Run the API with uvicorn
|
117 |
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. When this
    # file is run as a script it is already loaded as __main__; the import
    # string would make uvicorn import it again as module "app", executing
    # the module-level model loading a second time.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|