kevalfst committed on
Commit 5d10322 · verified · 1 Parent(s): cd3f0e3

Create app.py

Files changed (1)
  1. app.py +203 -0
app.py ADDED
@@ -0,0 +1,203 @@
import gradio as gr
import torch
import random
import hashlib
from diffusers import DiffusionPipeline
from transformers import pipeline
from diffusers.utils import export_to_video

# Optional: xformers optimization
try:
    import xformers
    has_xformers = True
except ImportError:
    has_xformers = False

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
MAX_SEED = 2**32 - 1

# Model lists ordered by size
image_models = {
    "Stable Diffusion 1.5 (light)": "runwayml/stable-diffusion-v1-5",
    "Stable Diffusion 2.1": "stabilityai/stable-diffusion-2-1",
    "Dreamlike 2.0": "dreamlike-art/dreamlike-photoreal-2.0",
    "Playground v2": "playgroundai/playground-v2-1024px-aesthetic",
    "Muse 512": "amused/muse-512-finetuned",
    "PixArt": "PixArt-alpha/PixArt-LCM-XL-2-1024-MS",
    "Kandinsky 3": "kandinsky-community/kandinsky-3",
    "BLIP Diffusion": "Salesforce/blipdiffusion",
    "SDXL Base 1.0 (heavy)": "stabilityai/stable-diffusion-xl-base-1.0",
    "OpenJourney (heavy)": "prompthero/openjourney"
}

text_models = {
    "GPT-2 (light)": "gpt2",
    "GPT-Neo 1.3B": "EleutherAI/gpt-neo-1.3B",
    "BLOOM 1.1B": "bigscience/bloom-1b1",
    "GPT-J 6B": "EleutherAI/gpt-j-6B",
    "Falcon 7B": "tiiuae/falcon-7b",
    "XGen 7B": "Salesforce/xgen-7b-8k-base",
    "BTLM 3B": "cerebras/btlm-3b-8k-base",
    "MPT 7B": "mosaicml/mpt-7b",
    "StableLM 2": "stabilityai/stablelm-2-1_6b",
    "LLaMA 2 7B (heavy)": "meta-llama/Llama-2-7b-hf"
}

video_models = {
    "CogVideoX-2B": "THUDM/CogVideoX-2b",
    "CogVideoX-5B": "THUDM/CogVideoX-5b",
    "AnimateDiff-Lightning": "ByteDance/AnimateDiff-Lightning",
    "ModelScope T2V": "damo-vilab/text-to-video-ms-1.7b",
    "VideoCrafter2": "VideoCrafter/VideoCrafter2",
    "Open-Sora-Plan-v1.2.0": "LanguageBind/Open-Sora-Plan-v1.2.0",
    "LTX-Video": "Lightricks/LTX-Video",
    "HunyuanVideo": "tencent/HunyuanVideo",
    "Latte-1": "maxin-cn/Latte-1",
    "LaVie": "Vchitect/LaVie"
}

# Caches
image_pipes = {}
text_pipes = {}
video_pipes = {}
image_cache = {}
text_cache = {}
video_cache = {}

def hash_inputs(*args):
    combined = "|".join(map(str, args))
    return hashlib.sha256(combined.encode()).hexdigest()

def generate_image(prompt, model_name, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    key = hash_inputs(prompt, model_name, seed)
    if key in image_cache:
        progress(100, desc="Using cached image.")
        return image_cache[key], seed

    progress(10, desc="Loading model...")
    if model_name not in image_pipes:
        pipe = DiffusionPipeline.from_pretrained(
            image_models[model_name],
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True
        )

        # torch.compile targets modules, not the pipeline object itself,
        # so compile the UNet when running under PyTorch 2.x
        if torch.__version__.startswith("2") and hasattr(pipe, "unet"):
            pipe.unet = torch.compile(pipe.unet)
        if has_xformers and device == "cuda":
            try:
                pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                pass

        pipe.to(device)
        image_pipes[model_name] = pipe

    pipe = image_pipes[model_name]

    progress(40, desc="Generating image...")
    result = pipe(prompt=prompt, generator=torch.manual_seed(seed), num_inference_steps=15, width=512, height=512)
    image = result.images[0]
    image_cache[key] = image

    progress(100, desc="Done.")
    return image, seed

def generate_text(prompt, model_name, progress=gr.Progress(track_tqdm=True)):
    key = hash_inputs(prompt, model_name)
    if key in text_cache:
        progress(100, desc="Using cached text.")
        return text_cache[key]

    progress(10, desc="Loading model...")
    if model_name not in text_pipes:
        text_pipes[model_name] = pipeline(
            "text-generation",
            model=text_models[model_name],
            device=0 if device == "cuda" else -1
        )
    pipe = text_pipes[model_name]

    progress(40, desc="Generating text...")
    result = pipe(prompt, max_length=100, do_sample=True)[0]['generated_text']
    text_cache[key] = result

    progress(100, desc="Done.")
    return result

def generate_video(prompt, model_name, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    key = hash_inputs(prompt, model_name, seed)
    if key in video_cache:
        progress(100, desc="Using cached video.")
        return video_cache[key], seed

    progress(10, desc="Loading model...")
    if model_name not in video_pipes:
        pipe = DiffusionPipeline.from_pretrained(
            video_models[model_name],
            torch_dtype=torch_dtype,
            # not every checkpoint ships fp16 weights; only request them on GPU
            variant="fp16" if device == "cuda" else None
        )

        # compile the UNet rather than the pipeline object under PyTorch 2.x
        if torch.__version__.startswith("2") and hasattr(pipe, "unet"):
            pipe.unet = torch.compile(pipe.unet)
        if has_xformers and device == "cuda":
            try:
                pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                pass

        pipe.to(device)
        video_pipes[model_name] = pipe

    pipe = video_pipes[model_name]

    progress(40, desc="Generating video...")
    result = pipe(prompt=prompt, generator=torch.manual_seed(seed), num_inference_steps=15)
    video_frames = result.frames[0]
    video_path = export_to_video(video_frames)
    video_cache[key] = video_path

    progress(100, desc="Done.")
    return video_path, seed

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# ⚡ Fast Multi-Model AI Playground with Caching")

    with gr.Tabs():
        # Image Generation
        with gr.Tab("🖼️ Image Generation"):
            img_prompt = gr.Textbox(label="Prompt")
            img_model = gr.Dropdown(choices=list(image_models.keys()), value="Stable Diffusion 1.5 (light)", label="Image Model")
            img_seed = gr.Slider(0, MAX_SEED, value=42, label="Seed")
            img_rand = gr.Checkbox(label="Randomize seed", value=True)
            img_btn = gr.Button("Generate Image")
            img_out = gr.Image()
            img_btn.click(fn=generate_image, inputs=[img_prompt, img_model, img_seed, img_rand], outputs=[img_out, img_seed])

        # Text Generation
        with gr.Tab("📝 Text Generation"):
            txt_prompt = gr.Textbox(label="Prompt")
            txt_model = gr.Dropdown(choices=list(text_models.keys()), value="GPT-2 (light)", label="Text Model")
            txt_btn = gr.Button("Generate Text")
            txt_out = gr.Textbox(label="Output Text")
            txt_btn.click(fn=generate_text, inputs=[txt_prompt, txt_model], outputs=[txt_out])

        # Video Generation
        with gr.Tab("🎥 Video Generation"):
            vid_prompt = gr.Textbox(label="Prompt")
            vid_model = gr.Dropdown(choices=list(video_models.keys()), value="CogVideoX-2B", label="Video Model")
            vid_seed = gr.Slider(0, MAX_SEED, value=42, label="Seed")
            vid_rand = gr.Checkbox(label="Randomize seed", value=True)
            vid_btn = gr.Button("Generate Video")
            vid_out = gr.Video()
            vid_btn.click(fn=generate_video, inputs=[vid_prompt, vid_model, vid_seed, vid_rand], outputs=[vid_out, vid_seed])

demo.launch()
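Note: the commit adds only app.py. For the Space to run, the imported packages also have to be installed; the sketch below is an assumed, unpinned requirements list inferred from the imports (accelerate is listed because low_cpu_mem_usage=True relies on it, and export_to_video additionally needs a video backend such as opencv-python or imageio-ffmpeg, depending on the diffusers version). This is not part of the commit.

    gradio
    torch
    transformers
    diffusers
    accelerate
    xformers  # optional: memory-efficient attention on CUDA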