sethchitty committed on
Commit
550894c
·
verified ·
1 Parent(s): e00e806

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ from diffusers import AutoPipelineForText2Image, DDIMScheduler
5
+ from transformers import CLIPVisionModelWithProjection
6
+ from diffusers.utils import load_image
7
+ from PIL import Image
8
+ import os
9
+ import json
10
+ import gc
11
+ import traceback
12
+
13
# Style reference images, keyed by the style name a request supplies in its
# "style" field. Each style maps to a list of image URLs.
_PIXAR_STYLE_BASE_URL = (
    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
)
STYLE_MAP = {
    "pixar": [f"{_PIXAR_STYLE_BASE_URL}/img{index}.png" for index in range(5)],
}
22
+
23
# Pick the device first and derive the precision from it: half precision is
# only used when a CUDA GPU is present, float32 otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"πŸš€ Device: {device}, torch_dtype: {torch_dtype}")

# Image encoder loaded from the IP-Adapter repository and handed to the
# pipeline below via `image_encoder=`.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch_dtype,
)

# NOTE: the pipeline is deliberately NOT moved to the GPU here. Device
# placement is decided once, after the adapters are loaded (see below).
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
    image_encoder=image_encoder,
    variant="fp16" if device == "cuda" else None,
)

pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

# Two IP-Adapters are loaded. The generation call passes its reference images
# in this same order: style images first, face image second.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=[
        "ip-adapter-plus_sdxl_vit-h.safetensors",
        "ip-adapter-plus-face_sdxl_vit-h.safetensors",
    ],
)
# One scale per adapter, presumably matched by position to `weight_name`
# above (0.7 for the first, 0.3 for the face adapter) -- confirm against the
# Diffusers IP-Adapter documentation.
pipeline.set_ip_adapter_scale([0.7, 0.3])

if device == "cuda":
    # Model CPU offload manages device placement itself. Moving the pipeline
    # to CUDA beforehand defeats most of the memory saving, so no `.to()` is
    # issued on this path.
    pipeline.enable_model_cpu_offload()
else:
    # Offloading needs an accelerator to offload *to*; on a CPU-only host the
    # pipeline simply stays on the CPU.
    pipeline.to(device)
pipeline.enable_vae_tiling()
52
+
53
_ERROR_IMAGE_PATH = "/mnt/data/error_image.png"
# (width, height) of the synthesised placeholder; same size as generated scenes.
_FALLBACK_SIZE = (512, 768)


def _fallback_image():
    """Return the image shown in place of a scene when generation fails.

    The placeholder file at ``_ERROR_IMAGE_PATH`` is used when it can be read.
    Otherwise a plain grey image is synthesised, so the error path itself
    cannot raise because the file is missing or unreadable.
    """
    try:
        with Image.open(_ERROR_IMAGE_PATH) as placeholder:
            # copy() forces the lazy decode to happen before the file closes.
            return placeholder.copy()
    except OSError:
        return Image.new("RGB", _FALLBACK_SIZE, color=(64, 64, 64))


def generate_single_scene(data):
    """Generate one scene image from a parsed request.

    Args:
        data: Mapping with the keys ``"character_image_url"`` (location of the
            face reference image), ``"style"`` (a key of ``STYLE_MAP``) and
            ``"scene"`` (the text prompt).

    Returns:
        The generated ``PIL.Image.Image``. On any failure (missing key,
        unknown style, download error, pipeline error) the fallback image is
        returned instead; this function does not raise for those cases.
    """
    try:
        print("πŸ“₯ Full input received:")
        print(json.dumps(data, indent=2))

        character_image_url = data["character_image_url"]
        style = data["style"]
        scene_prompt = data["scene"]

        # Fail fast with a readable message instead of handing the pipeline
        # an empty style-image list and failing somewhere deep inside it.
        style_urls = STYLE_MAP.get(style)
        if not style_urls:
            raise ValueError(
                f"Unknown style {style!r}; expected one of {sorted(STYLE_MAP)}"
            )

        print("πŸ”„ Loading reference and style images...")
        # NOTE(review): `character_image_url` comes straight from user input.
        # Confirm which kinds of location `load_image` accepts and consider
        # restricting this to http(s) URLs before exposing the app publicly.
        face_image = load_image(character_image_url)
        style_images = [load_image(url) for url in style_urls]

        # Collect Python garbage first so the tensors it frees can actually
        # be released from the CUDA cache.
        gc.collect()
        torch.cuda.empty_cache()

        print("🎨 Starting generation...")
        result = pipeline(
            prompt=scene_prompt,
            # One entry per loaded IP-Adapter: style images, then the face.
            ip_adapter_image=[style_images, face_image],
            negative_prompt="blurry, bad anatomy, low quality",
            width=512,
            height=768,
            guidance_scale=5.0,
            num_inference_steps=15,
            # Fixed seed: identical requests are seeded identically.
            generator=torch.Generator(device).manual_seed(42),
        )

        image = result.images[0] if hasattr(result, "images") else result
        print(f"πŸ–ΌοΈ Image generated. Type: {type(image)}")

        if isinstance(image, Image.Image):
            print("βœ… Valid image object returned.")
            return image

        print("❌ Invalid image object. Returning fallback.")
        return _fallback_image()

    except Exception as e:
        # Top-level boundary for the UI: log everything, show the fallback.
        print(f"❌ Exception occurred: {e}")
        traceback.print_exc()
        return _fallback_image()


def generate_from_json(json_input_text):
    """Parse the request JSON from the UI textbox and generate a scene.

    Args:
        json_input_text: Raw text expected to contain a JSON object; see
            ``generate_single_scene`` for the keys it must provide.

    Returns:
        The generated image, or the fallback image when the text is not
        valid JSON or generation fails.
    """
    try:
        data = json.loads(json_input_text)
    except (TypeError, ValueError) as e:
        # json.JSONDecodeError is a ValueError; TypeError covers a None input.
        print(f"❌ JSON parsing error: {e}")
        traceback.print_exc()
        return _fallback_image()
    # Generation failures are logged and handled inside generate_single_scene,
    # so they are no longer misreported as JSON parsing errors.
    return generate_single_scene(data)
104
+
105
# Gradio front end: one JSON textbox in, one image out (either the generated
# scene or the fallback image).
iface = gr.Interface(
    title="Debug Storybook Scene Generator",
    description="Displays logs and returns fallback image on error.",
    fn=generate_from_json,
    inputs=gr.Textbox(label="Input JSON", lines=10),
    outputs=gr.Image(label="Generated Scene or Error"),
)

# Turn on request queuing, then start the server with a public share link.
iface.queue()
iface.launch(share=True)