Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from diffusers import AutoPipelineForText2Image, DDIMScheduler | |
from transformers import CLIPVisionModelWithProjection | |
from diffusers.utils import load_image | |
from PIL import Image | |
import os | |
import json | |
import gc | |
import traceback | |
# Reference images for each supported art style, keyed by the "style" value
# read from the request JSON. The "pixar" style uses five sample images that
# share one dataset prefix and differ only in their index.
STYLE_MAP = {
    "pixar": [
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main"
        f"/style_ziggy/img{index}.png"
        for index in range(5)
    ],
}
# float16 weights on GPU, float32 on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"π Device: {device}, torch_dtype: {torch_dtype}")

# CLIP ViT-H image encoder; both IP-Adapter weight files loaded below are the
# "vit-h" SDXL variants, so they share this encoder.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch_dtype,
)

# Build the full pipeline (base model + scheduler + IP-Adapters) first and
# decide device placement exactly once at the end.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
    image_encoder=image_encoder,
    variant="fp16" if device == "cuda" else None,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

# Two IP-Adapters, in this order: the general "plus" adapter and the
# "plus-face" adapter. Scales are given per adapter in the same order.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=[
        "ip-adapter-plus_sdxl_vit-h.safetensors",
        "ip-adapter-plus-face_sdxl_vit-h.safetensors",
    ],
)
pipeline.set_ip_adapter_scale([0.7, 0.3])

if device == "cuda":
    # Model CPU offload manages placement itself: it moves the pipeline back
    # to the CPU and moves each sub-model to the GPU only while it runs.
    # Calling .to("cuda") beforehand would just waste a full GPU upload.
    pipeline.enable_model_cpu_offload()
else:
    # Offloading requires an accelerator; on CPU-only hosts keep everything
    # on the CPU instead of enabling it.
    pipeline.to(device)
pipeline.enable_vae_tiling()
def _error_image(path="/mnt/data/error_image.png"):
    """Return the fallback image shown in the UI when a scene cannot be produced.

    The pre-rendered error image at ``path`` is used when it can be read. If it
    is missing or unreadable, a plain in-memory placeholder is returned instead,
    so the error path itself never raises.

    Args:
        path: Location of the pre-rendered error image.

    Returns:
        A ``PIL.Image.Image``.
    """
    try:
        fallback = Image.open(path)
        fallback.load()  # decode now so a bad file fails here, not inside Gradio
        return fallback
    except OSError as err:
        print(f"Fallback image unavailable ({err}); using placeholder.")
        # Same size as the generated scenes (width=512, height=768).
        return Image.new("RGB", (512, 768), color=(128, 0, 0))


def generate_single_scene(data):
    """Generate one storybook scene image from a parsed request.

    Args:
        data: Mapping with keys ``"character_image_url"`` (URL/path of the
            character's face reference), ``"style"`` (a key of ``STYLE_MAP``)
            and ``"scene"`` (the text prompt).

    Returns:
        The generated ``PIL.Image.Image``. If anything fails (missing keys,
        unknown style, image download, or generation), the error is logged
        with its traceback and a fallback error image is returned instead.
        This function does not raise ``Exception`` subclasses.
    """
    print("π₯ Full input received:")
    print(json.dumps(data, indent=2))
    try:
        character_image_url = data["character_image_url"]
        style = data["style"]
        scene_prompt = data["scene"]

        # Reject unknown styles up front; otherwise an empty style-image list
        # would be handed to the style IP-Adapter.
        style_urls = STYLE_MAP.get(style)
        if not style_urls:
            raise ValueError(
                f"Unknown style {style!r}; expected one of {sorted(STYLE_MAP)}"
            )

        print("π Loading reference and style images...")
        face_image = load_image(character_image_url)
        style_images = [load_image(url) for url in style_urls]

        # Collect unreachable objects first so their CUDA blocks are actually
        # free when the allocator cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()

        print("π¨ Starting generation...")
        result = pipeline(
            prompt=scene_prompt,
            # One entry per loaded IP-Adapter, in load order: a list of style
            # images for the first adapter, the face image for the second.
            ip_adapter_image=[style_images, face_image],
            negative_prompt="blurry, bad anatomy, low quality",
            width=512,
            height=768,
            guidance_scale=5.0,
            num_inference_steps=15,
            # Fixed seed: the same request produces the same image.
            generator=torch.Generator(device).manual_seed(42),
        )
        image = result.images[0] if hasattr(result, "images") else result
        print(f"πΌοΈ Image generated. Type: {type(image)}")
    except Exception as e:
        # UI boundary: log the real cause, then show a fallback instead of crashing.
        print(f"β Exception occurred: {e}")
        traceback.print_exc()
        return _error_image()

    if isinstance(image, Image.Image):
        print("β Valid image object returned.")
        return image
    print("β Invalid image object. Returning fallback.")
    return _error_image()
def generate_from_json(json_input_text):
    """Parse the textbox JSON and render a single scene from it.

    Args:
        json_input_text: Raw text from the "Input JSON" textbox. It is expected
            to decode to an object with ``character_image_url``, ``style`` and
            ``scene`` keys (as read by ``generate_single_scene``).

    Returns:
        The result of ``generate_single_scene`` for the decoded request, or a
        fallback error image when the text cannot be decoded as JSON.
    """
    try:
        data = json.loads(json_input_text)
    except (TypeError, ValueError) as e:
        # Only decoding failures are handled here (ValueError covers
        # json.JSONDecodeError; TypeError covers a None input). Failures during
        # generation are not mislabelled as JSON errors.
        print(f"β JSON parsing error: {e}")
        traceback.print_exc()
        try:
            fallback = Image.open("/mnt/data/error_image.png")
            fallback.load()
            return fallback
        except OSError:
            # The error image may not exist on this host; never crash the error path.
            return Image.new("RGB", (512, 768), color=(128, 0, 0))
    return generate_single_scene(data)
# Debug UI: a single JSON textbox in, a single image (scene or fallback) out.
_json_input = gr.Textbox(label="Input JSON", lines=10)
_scene_output = gr.Image(label="Generated Scene or Error")
iface = gr.Interface(
    fn=generate_from_json,
    inputs=_json_input,
    outputs=_scene_output,
    title="Debug Storybook Scene Generator",
    description="Displays logs and returns fallback image on error.",
)

# Enable Gradio's request queue, then start the server; share=True requests
# a public share link.
iface.queue()
iface.launch(share=True)