sethchitty committed
Commit c5032e4 · verified · 1 Parent(s): bcf0c88

Upload app.py

Files changed (1)
app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image
from PIL import Image
import json
import gc
import traceback

# Reference images defining each visual style for the style IP-Adapter.
STYLE_MAP = {
    "pixar": [
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img0.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img1.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img2.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img3.png",
        "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img4.png"
    ]
}

torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Device: {device}, torch_dtype: {torch_dtype}")

# ViT-H CLIP image encoder shared by both IP-Adapter weights loaded below.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch_dtype,
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
    image_encoder=image_encoder,
    variant="fp16" if torch.cuda.is_available() else None
)

pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
# Two IP-Adapters: the first conditions on the style images, the second on the character's face.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=[
        "ip-adapter-plus_sdxl_vit-h.safetensors",
        "ip-adapter-plus-face_sdxl_vit-h.safetensors"
    ]
)
pipeline.set_ip_adapter_scale([0.7, 0.3])
if torch.cuda.is_available():
    # Model CPU offload manages device placement itself, so the pipeline
    # is deliberately not moved to CUDA with .to() beforehand.
    pipeline.enable_model_cpu_offload()
pipeline.enable_vae_tiling()

def generate_single_scene(data):
    print("📥 Single scene input received:")
    print(json.dumps(data, indent=2))

    try:
        character_image_url = data["character_image_url"]
        style = data["style"]
        scene_prompt = data["scene"]

        face_image = load_image(character_image_url)
        style_images = [load_image(url) for url in STYLE_MAP.get(style, [])]
        if not style_images:
            print(f"❌ Unknown style: {style!r}")
            return None

        # Release cached memory from any previous generation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        result = pipeline(
            prompt=scene_prompt,
            # First entry feeds the style adapter, second the face adapter.
            ip_adapter_image=[style_images, face_image],
            negative_prompt="blurry, bad anatomy, low quality",
            width=512,
            height=768,
            guidance_scale=5.0,
            num_inference_steps=15,
            generator=torch.Generator(device).manual_seed(42)
        )

        image = result.images[0] if hasattr(result, "images") else result
        print(f"🖼️ Image type: {type(image)}")

        if isinstance(image, Image.Image):
            print("✅ Returning valid image object")
            return image
        else:
            print("❌ Invalid image type. Not returning.")
            return None

    except Exception as e:
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return None

def generate_from_json(json_input_text):
    try:
        data = json.loads(json_input_text)
        return generate_single_scene(data)
    except Exception as e:
        print(f"❌ JSON parse or generation error: {e}")
        traceback.print_exc()
        return None

iface = gr.Interface(
    fn=generate_from_json,
    inputs=gr.Textbox(label="Input JSON", lines=10, placeholder='{"character_image_url": "...", "style": "pixar", "scene": "..."}'),
    outputs=gr.Image(label="Generated Scene"),
    title="Single-Scene Storybook Generator",
    description="Send one scene at a time to generate consistent character-based images."
)

iface.queue().launch(share=True)
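
For reference, a minimal sketch of how a client might call this app once the Space is running. The Space id and the character image URL below are placeholders, and /predict is Gradio's default endpoint name for a single gr.Interface; only the JSON keys (character_image_url, style, scene) come from app.py itself.

import json
from gradio_client import Client

client = Client("sethchitty/<space-name>")  # placeholder Space id

payload = {
    "character_image_url": "https://example.com/character.png",  # placeholder URL
    "style": "pixar",  # must match a key in STYLE_MAP
    "scene": "the character flying a red kite on a windy hill"
}

# gr.Interface exposes its function as /predict by default; the call returns
# a local file path to the generated image, or None if the server errored.
result = client.predict(json.dumps(payload), api_name="/predict")
print(result)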