# educational/app.py
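# Gradio app that generates a storybook-style scene from a JSON request
# (character image URL, style name, scene prompt) using SDXL with two IP-Adapters:
# one conditioned on style reference images and one on the character's face.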
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image
from PIL import Image
import os
import json
import gc
import traceback
STYLE_MAP = {
    "pixar": [
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img0.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img1.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img2.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img3.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img4.png",
    ]
}
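# Only "pixar" is defined above; further styles could be added as extra keys
# mapping a style name to a list of reference-image URLs.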
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸš€ Device: {device}, torch_dtype: {torch_dtype}")
# CLIP vision encoder used by the IP-Adapter checkpoints loaded below.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch_dtype,
)
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
    image_encoder=image_encoder,
    variant="fp16" if torch.cuda.is_available() else None,
).to(device)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=[
        "ip-adapter-plus_sdxl_vit-h.safetensors",
        "ip-adapter-plus-face_sdxl_vit-h.safetensors",
    ],
)
pipeline.set_ip_adapter_scale([0.7, 0.3])
if torch.cuda.is_available():
    # Model CPU offloading needs a CUDA device; skip it on CPU-only hosts.
    pipeline.enable_model_cpu_offload()
pipeline.enable_vae_tiling()
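# Note: with multiple IP-Adapters loaded, diffusers applies scales in load order,
# so the [0.7, 0.3] above is assumed to weight the style adapter
# ("ip-adapter-plus_sdxl_vit-h") at 0.7 and the face adapter
# ("ip-adapter-plus-face_sdxl_vit-h") at 0.3. ip_adapter_image in
# generate_single_scene() follows the same [style_images, face_image] order.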
def generate_single_scene(data):
    print("📥 Full input received:")
    print(json.dumps(data, indent=2))
    try:
        character_image_url = data["character_image_url"]
        style = data["style"]
        scene_prompt = data["scene"]

        print("🔄 Loading reference and style images...")
        face_image = load_image(character_image_url)
        style_images = [load_image(url) for url in STYLE_MAP.get(style, [])]

        # Free cached memory between requests before starting a new generation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        print("🎨 Starting generation...")
        result = pipeline(
            prompt=scene_prompt,
            ip_adapter_image=[style_images, face_image],
            negative_prompt="blurry, bad anatomy, low quality",
            width=512,
            height=768,
            guidance_scale=5.0,
            num_inference_steps=15,
            generator=torch.Generator(device).manual_seed(42),
        )
        image = result.images[0] if hasattr(result, "images") else result
        print(f"🖼️ Image generated. Type: {type(image)}")

        if isinstance(image, Image.Image):
            print("✅ Valid image object returned.")
            return image
        else:
            print("❌ Invalid image object. Returning fallback.")
            return Image.open("/mnt/data/error_image.png")
    except Exception as e:
        print(f"❌ Exception occurred: {e}")
        traceback.print_exc()
        return Image.open("/mnt/data/error_image.png")
def generate_from_json(json_input_text):
    try:
        data = json.loads(json_input_text)
        return generate_single_scene(data)
    except Exception as e:
        print(f"❌ JSON parsing error: {e}")
        traceback.print_exc()
        return Image.open("/mnt/data/error_image.png")
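# Example of the JSON payload this interface expects (the URL below is an
# illustrative placeholder, not a real asset):
# {
#     "character_image_url": "https://example.com/character.png",
#     "style": "pixar",
#     "scene": "The character waves hello in a sunny meadow"
# }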
iface = gr.Interface(
    fn=generate_from_json,
    inputs=gr.Textbox(label="Input JSON", lines=10),
    outputs=gr.Image(label="Generated Scene or Error"),
    title="Debug Storybook Scene Generator",
    description="Displays logs and returns fallback image on error.",
)
iface.queue().launch(share=True)