File size: 3,837 Bytes
550894c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image
from PIL import Image
import os
import json
import gc
import traceback

# Reference images per art style, keyed by the request's "style" value.
# The whole list for a style is used as that style's reference set.
STYLE_MAP = {
    "pixar": [
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/"
        f"style_ziggy/img{n}.png"
        for n in range(5)
    ],
}

# Use float16 weights only when a CUDA GPU is available; otherwise float32 on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Device: {device}, torch_dtype: {torch_dtype}")

# CLIP vision encoder handed to the pipeline explicitly so the ViT-H
# IP-Adapters loaded below have an image encoder to embed reference images.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch_dtype,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
    image_encoder=image_encoder,
    # The fp16 weight variant is only requested for GPU use.
    variant="fp16" if torch.cuda.is_available() else None,
)

pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

# Two IP-Adapters, in this order: the general "plus" adapter and the
# "plus-face" adapter. Callers must pass one image entry per adapter in the
# same order.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=[
        "ip-adapter-plus_sdxl_vit-h.safetensors",
        "ip-adapter-plus-face_sdxl_vit-h.safetensors",
    ],
)
# Per-adapter conditioning strength, same order as weight_name above.
pipeline.set_ip_adapter_scale([0.7, 0.3])

# Device placement. With a GPU, model CPU offload moves each sub-model onto
# the GPU only while it runs. Do NOT call `.to("cuda")` before enabling it:
# the diffusers docs warn that doing so leaves little of the memory saving.
# Offload hooks need an accelerator to target, so on a CPU-only machine the
# pipeline simply stays on the CPU.
if device == "cuda":
    pipeline.enable_model_cpu_offload()
else:
    pipeline.to(device)
pipeline.enable_vae_tiling()

def _load_error_image(path="/mnt/data/error_image.png", size=(512, 768)):
    """Return the fallback error image, or a grey placeholder if it can't be read.

    Args:
        path: Image file shown to the user when generation fails.
        size: ``(width, height)`` of the placeholder used when *path* cannot
            be opened; matches the 512x768 generation size used below.

    Returns:
        A fully loaded ``PIL.Image.Image`` that holds no open file handle.
    """
    try:
        # Image.open() is lazy and keeps the file open; read the pixels inside
        # the ``with`` block and return a detached copy.
        with Image.open(path) as img:
            img.load()
            return img.copy()
    except OSError as e:
        # The error path must never raise itself (e.g. the file is missing).
        print(f"⚠️ Could not load fallback image {path!r}: {e}")
        return Image.new("RGB", size, color=(128, 128, 128))


def generate_single_scene(data):
    """Render one storybook scene from a request dict.

    Args:
        data: Mapping with keys ``"character_image_url"`` (face reference,
            anything diffusers' ``load_image`` accepts), ``"style"`` (a key
            of ``STYLE_MAP``) and ``"scene"`` (the text prompt). An optional
            ``"seed"`` (int, default 42) seeds the generator.

    Returns:
        The generated ``PIL.Image.Image``, or a fallback error image if
        anything goes wrong. Errors are logged, never raised.
    """
    try:
        print("📥 Full input received:")
        print(json.dumps(data, indent=2))

        character_image_url = data["character_image_url"]
        style = data["style"]
        scene_prompt = data["scene"]
        seed = int(data.get("seed", 42))

        style_urls = STYLE_MAP.get(style)
        if not style_urls:
            # An empty style list would reach the first IP-Adapter with no
            # images and fail deep inside the pipeline; fail early and clearly.
            raise ValueError(
                f"Unknown style {style!r}; expected one of {sorted(STYLE_MAP)}"
            )

        print("🔄 Loading reference and style images...")
        face_image = load_image(character_image_url)
        style_images = [load_image(url) for url in style_urls]

        # Collect garbage first so memory from unreachable tensors is freed
        # before asking the CUDA caching allocator to release cached blocks.
        gc.collect()
        torch.cuda.empty_cache()

        print("🎨 Starting generation...")
        result = pipeline(
            prompt=scene_prompt,
            # One entry per loaded IP-Adapter, in load order: style set, face.
            ip_adapter_image=[style_images, face_image],
            negative_prompt="blurry, bad anatomy, low quality",
            width=512,
            height=768,
            guidance_scale=5.0,
            num_inference_steps=15,
            generator=torch.Generator(device).manual_seed(seed),
        )

        image = result.images[0] if hasattr(result, "images") else result
        print(f"🖼️ Image generated. Type: {type(image)}")

        if isinstance(image, Image.Image):
            print("✅ Valid image object returned.")
            return image

        print("❌ Invalid image object. Returning fallback.")
        return _load_error_image()

    except Exception as e:
        # UI boundary: log the full traceback but always hand back an image.
        print(f"❌ Exception occurred: {e}")
        traceback.print_exc()
        return _load_error_image()

def _fallback_error_image(path="/mnt/data/error_image.png", size=(512, 768)):
    """Return the error image at *path*, or a grey placeholder if it can't be read.

    Error handling must never raise itself: a missing or unreadable file
    yields a generated ``size`` = ``(width, height)`` placeholder instead.
    """
    try:
        with Image.open(path) as img:
            # Force the lazy read so the file handle is closed on return.
            img.load()
            return img.copy()
    except OSError as e:
        print(f"⚠️ Could not load fallback image {path!r}: {e}")
        return Image.new("RGB", size, color=(128, 128, 128))


def generate_from_json(json_input_text):
    """Parse the request JSON from the UI textbox and render the scene.

    Args:
        json_input_text: JSON object text such as
            ``{"character_image_url": "...", "style": "pixar", "scene": "..."}``.

    Returns:
        The generated image, or a fallback error image if the text is not
        valid JSON or generation fails. Errors are logged, never raised, so
        the UI always receives an image.
    """
    try:
        data = json.loads(json_input_text)
    except (TypeError, ValueError) as e:
        # ValueError covers json.JSONDecodeError (e.g. an empty textbox);
        # TypeError covers a non-string input such as None.
        print(f"❌ JSON parsing error: {e}")
        traceback.print_exc()
        return _fallback_error_image()

    try:
        return generate_single_scene(data)
    except Exception as e:
        # Top-level UI boundary: anything escaping generation is logged under
        # its own label rather than being reported as a JSON error.
        print(f"❌ Generation error: {e}")
        traceback.print_exc()
        return _fallback_error_image()

# Debug UI: a single textbox takes the request JSON handled by
# generate_from_json; the output shows the scene or the fallback image.
iface = gr.Interface(
    fn=generate_from_json,
    inputs=gr.Textbox(label="Input JSON", lines=10),
    outputs=gr.Image(label="Generated Scene or Error"),
    title="Debug Storybook Scene Generator",
    description="Displays logs and returns fallback image on error.",
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. to reuse `iface` or `generate_from_json`) does not block.
    iface.queue().launch(share=True)