AlekseyCalvin commited on
Commit
fac9566
·
verified ·
1 Parent(s): fd47cf1

Create app2.py

Browse files
Files changed (1) hide show
  1. app2.py +374 -0
app2.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import spaces
4
+ import torch
5
+ import random
6
+ import json
7
+ import os
8
+ from PIL import Image
9
+ from diffusers import FluxKontextPipeline
10
+ from diffusers.utils import load_image
11
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
12
+ from safetensors.torch import load_file
13
+ import requests
14
+ import re
15
+
16
# Load Kontext model
# Largest 32-bit signed integer: upper bound for seed randomization in the UI
# slider and in infer_with_lora's randomize path.
MAX_SEED = np.iinfo(np.int32).max

# NOTE(review): .to("cuda") assumes a CUDA device exists at import time — on a
# CPU-only host this raises immediately; confirm the deployment target.
pipe = FluxKontextPipeline.from_pretrained("LPX55/FLUX.1_Kontext-Lightning", torch_dtype=torch.bfloat16).to("cuda")
20
+
21
# Load LoRA data (you'll need to create this JSON file or modify to load your LoRAs)
# Explicit UTF-8 avoids platform-dependent default encodings when titles or
# trigger words contain non-ASCII characters.
with open("flux_loras.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalize each JSON entry into the dict shape the gallery and inference
# code expect; optional fields fall back to sensible defaults.
flux_loras_raw = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item.get("trigger_word", ""),
        "trigger_position": item.get("trigger_position", "prepend"),
        "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
    }
    for item in data
]
print(f"Loaded {len(flux_loras_raw)} LoRAs from JSON")

# Global variables for LoRA management
current_lora = None  # LoRA dict currently loaded into the pipeline, or None
lora_cache = {}      # repo_id -> locally downloaded weights file path
40
+
41
def load_lora_weights(repo_id, weights_filename):
    """Download a LoRA weights file from the HuggingFace Hub, with caching.

    Returns the local file path on success, or None when the download fails
    (the error is printed, not raised, so inference can proceed without a LoRA).
    """
    try:
        cached_path = lora_cache.get(repo_id)
        if cached_path is None:
            cached_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
            lora_cache[repo_id] = cached_path
        return cached_path
    except Exception as e:
        print(f"Error loading LoRA from {repo_id}: {e}")
        return None
51
+
52
def update_selection(selected_state: gr.SelectData, flux_loras):
    """Update UI when a LoRA is selected in the gallery.

    Returns (title markdown, prompt-textbox update, selected index).
    """
    # Guard against a stale selection index that is out of range for the
    # current (possibly re-sorted) LoRA list.
    if selected_state.index >= len(flux_loras):
        return "### No LoRA selected", gr.update(), None

    lora_repo = flux_loras[selected_state.index]["repo"]

    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
    # Plain string literal: the original used an f-string with no placeholders
    # (ruff F541) and also bound an unused `trigger_word` local — both removed.
    new_placeholder = "optional description, e.g. 'a man with glasses and a beard'"

    return updated_text, gr.update(placeholder=new_placeholder), selected_state.index
64
+
65
def get_huggingface_lora(link):
    """Resolve a 'user/repo' HuggingFace link to LoRA download metadata.

    Returns (repo_name, weights_filename, trigger_word).
    Raises Exception when the link is malformed or repo metadata is unreadable.
    """
    split_link = link.split("/")
    # Guard clause: only the bare "user/repo" form is accepted.
    if len(split_link) != 2:
        raise Exception("Invalid HuggingFace repository format")

    try:
        model_card = ModelCard.load(link)
        trigger_word = model_card.data.get("instance_prompt", "")

        fs = HfFileSystem()
        list_of_files = fs.ls(link, detail=False)
        safetensors_file = None

        # Prefer a file that is explicitly named as a LoRA checkpoint.
        for file in list_of_files:
            if file.endswith(".safetensors") and "lora" in file.lower():
                safetensors_file = file.split("/")[-1]
                break

        # Fall back to the conventional diffusers LoRA filename.
        if not safetensors_file:
            safetensors_file = "pytorch_lora_weights.safetensors"

        return split_link[1], safetensors_file, trigger_word
    except Exception as e:
        # Chain the cause so the original traceback is preserved for debugging.
        raise Exception(f"Error loading LoRA: {e}") from e
90
+
91
def load_custom_lora(link):
    """Load custom LoRA from user input.

    Returns a 7-tuple matched positionally to the outputs of the
    custom_model.input event: (card update, card HTML, remove-button update,
    custom LoRA dict, gallery with selection cleared, title markdown,
    cleared gallery selection index).
    """
    # Empty textbox: hide the card/button and reset all selection state.
    if not link:
        return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on a LoRA in the gallery to select it", None

    try:
        repo_name, weights_file, trigger_word = get_huggingface_lora(link)

        # HTML summary card shown under the custom-LoRA textbox.
        card = f'''
        <div style="border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 10px 0;">
            <span><strong>Loaded custom LoRA:</strong></span>
            <div style="margin-top: 8px;">
                <h4>{repo_name}</h4>
                <small>{"Using: <code><b>"+trigger_word+"</b></code> as trigger word" if trigger_word else "No trigger word found"}</small>
            </div>
        </div>
        '''

        # NOTE(review): "repo" stores the full user-entered link, while the
        # card shows only the repo name — load_lora_weights later receives the
        # link as repo_id; confirm this is the intended identifier.
        custom_lora_data = {
            "repo": link,
            "weights": weights_file,
            "trigger_word": trigger_word
        }

        return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}", None

    except Exception as e:
        # On failure, surface the error text in the card and keep the
        # gallery-selection prompt as the title.
        return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### Click on a LoRA in the gallery to select it", None
119
+
120
def remove_custom_lora():
    """Reset all custom-LoRA UI state (textbox, card, button, stored dict)."""
    return (
        "",                         # clear the custom_model textbox
        gr.update(visible=False),   # hide the remove button
        gr.update(visible=False),   # hide the HTML card
        None,                       # drop the stored custom LoRA dict
        None,                       # clear the selected-state index
    )
123
+
124
def classify_gallery(flux_loras):
    """Order the LoRA list by like count, most liked first.

    Returns (gallery_items, ranked) where gallery_items is a list of
    (image, title) pairs for gr.Gallery and ranked is the reordered source
    list to keep in session state. Entries without a "likes" key count as 0.
    """
    def like_count(entry):
        return entry.get("likes", 0)

    ranked = sorted(flux_loras, key=like_count, reverse=True)

    gallery_items = []
    for entry in ranked:
        gallery_items.append((entry["image"], entry["title"]))

    return gallery_items, ranked
128
+
129
def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, steps=28, guidance_scale=2.5, lora_scale=1.75, width=960, height=1280, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
    """Wrapper function to handle state serialization"""
    # Thin passthrough so gradio State values arrive as plain positional
    # arguments; all real work happens in infer_with_lora.
    # NOTE(review): this wrapper's default lora_scale (1.75) differs from
    # infer_with_lora's default (1.0); since every argument is forwarded
    # explicitly, the wrapper's value wins — confirm which default is intended.
    return infer_with_lora(input_image, prompt, selected_index, custom_lora, seed, randomize_seed, steps, guidance_scale, lora_scale, width, height, flux_loras, progress)
132
+
133
@spaces.GPU
def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, steps=28, guidance_scale=2.5, lora_scale=1.0, width=960, height=1280, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
    """Generate image with selected LoRA.

    Returns (image or None, seed actually used, reuse-button visibility
    update). Inference errors are caught and reported as a None image.
    """
    global current_lora, pipe

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Determine which LoRA to use: a custom LoRA takes priority over a
    # gallery selection.
    lora_to_use = None
    if custom_lora:
        lora_to_use = custom_lora
    elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
        lora_to_use = flux_loras[selected_index]
    # (A stray debug print of len(flux_loras) was removed here: it crashed
    # with a TypeError whenever flux_loras was None, its default.)

    # Load LoRA if needed (skip when the requested LoRA is already active).
    if lora_to_use and lora_to_use != current_lora:
        try:
            # Unload current LoRA
            if current_lora:
                pipe.unload_lora_weights()

            # Load new LoRA
            lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
            if lora_path:
                pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
                pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
                print(f"loaded: {lora_path} with scale {lora_scale}")
            current_lora = lora_to_use

        except Exception as e:
            print(f"Error loading LoRA: {e}")
            # Continue without LoRA
    elif lora_to_use:
        # Only claim a LoRA is loaded when one actually is; the original
        # `else` branch printed this even with no LoRA selected at all.
        print(f"using already loaded lora: {lora_to_use}")

    input_image = input_image.convert("RGB")

    # Build the editing prompt around the LoRA's trigger word. Guard against
    # lora_to_use being None — the original subscripted it unconditionally
    # and raised TypeError when no LoRA was selected.
    trigger_word = lora_to_use.get("trigger_word", "") if lora_to_use else ""
    if trigger_word == ", How2Draw":
        prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
    elif trigger_word == "__ ":
        prompt = f" {prompt}. Accurately render the toolimpact logo and any tool impact iconography. The toolimpact logo begins with a two-line-tall drop-cap capital letter T with a dot in the center of its top bar."
    else:
        prompt = f" {prompt}. convert the style of this photo or image to {trigger_word}. Maintain the facial identity of any persons and the general features of the image!"

    try:
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(seed),
            width=width,
            height=height,
            max_area=width * height
        ).images[0]

        return image, seed, gr.update(visible=True)

    except Exception as e:
        print(f"Error during inference: {e}")
        return None, seed, gr.update(visible=False)
196
+
197
# CSS styling for the Blocks layout below; selectors match the elem_id values
# passed to the components (main_app, box_column, prompt, run_button, gallery).
css = """
#main_app {
    display: flex;
    gap: 20px;
}
#box_column {
    min-width: 400px;
}
#selected_lora {
    color: #2563eb;
    font-weight: bold;
}
#prompt {
    flex-grow: 1;
}
#run_button {
    background: linear-gradient(45deg, #2563eb, #3b82f6);
    color: white;
    border: none;
    padding: 8px 16px;
    border-radius: 6px;
    font-weight: bold;
}
.custom_lora_card {
    background: #f8fafc;
    border: 1px solid #e2e8f0;
    border-radius: 8px;
    padding: 12px;
    margin: 8px 0;
}
#gallery{
    overflow: scroll !important
}
"""
232
+
233
# Create Gradio interface
with gr.Blocks(css=css) as demo:
    # Per-session copy of the LoRA metadata so the order produced by
    # classify_gallery survives between events.
    gr_flux_loras = gr.State(value=flux_loras_raw)

    title = gr.HTML(
        """<h1> Fast FLUX.1 Kontext w/LoRAs by Silver Age Poets & SOON®
<br><small style="font-size: 13px; opacity: 0.75;">Edit images w/our trained adapters as style templates! Only 8 steps! </small></h1>""",
    )

    # Currently selected gallery index (None when nothing / custom selected)
    selected_state = gr.State(value=None)
    # Dict describing a user-supplied LoRA, or None
    custom_loaded_lora = gr.State(value=None)

    with gr.Row(elem_id="main_app"):
        # Left column: input image, LoRA gallery, custom-LoRA entry.
        with gr.Column(scale=4, elem_id="box_column"):
            with gr.Group(elem_id="gallery_box"):
                input_image = gr.Image(label="Upload a picture", type="pil", height=300)

                gallery = gr.Gallery(
                    label="Pick a LoRA",
                    allow_preview=False,
                    columns=3,
                    elem_id="gallery",
                    show_share_button=False,
                    height=400
                )

                custom_model = gr.Textbox(
                    label="Or enter a custom HuggingFace FLUX LoRA",
                    placeholder="e.g., username/lora-name",
                    visible=True
                )
                custom_model_card = gr.HTML(visible=False)
                custom_model_button = gr.Button("Remove custom LoRA", visible=True)

        # Right column: prompt, result, and advanced generation settings.
        with gr.Column(scale=5):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Editing Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                    placeholder="optional description, e.g. 'colorize and stylize, leave all else as is'",
                    elem_id="prompt"
                )
                run_button = gr.Button("Generate", elem_id="run_button")

            result = gr.Image(label="Generated Image", interactive=False)
            reuse_button = gr.Button("Reuse this image", visible=False)

            with gr.Accordion("Advanced Settings", open=True):
                lora_scale = gr.Slider(
                    label="LoRA Scale",
                    minimum=0,
                    maximum=2,
                    step=0.1,
                    value=1.5,
                    info="Controls the strength of the LoRA effect"
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=40,
                    value=10,
                    step=1
                )
                width = gr.Slider(
                    label="Width",
                    minimum=128,
                    maximum=2560,
                    step=1,
                    value=960,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=128,
                    maximum=2560,
                    step=1,
                    value=1280,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=2.8,
                )

    prompt_title = gr.Markdown(
        value="### Click on a LoRA in the gallery to select it",
        visible=True,
        elem_id="selected_lora",
    )

    # Event handlers
    # NOTE(review): custom_model_card appears twice in this outputs list
    # (load_custom_lora's first return drives visibility, the second the HTML
    # value) — verify the installed gradio version accepts duplicate output
    # components.
    custom_model.input(
        fn=load_custom_lora,
        inputs=[custom_model],
        outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
    )

    custom_model_button.click(
        fn=remove_custom_lora,
        outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
    )

    # Gallery click: update the title, the prompt placeholder, and the index.
    gallery.select(
        fn=update_selection,
        inputs=[gr_flux_loras],
        outputs=[prompt_title, prompt, selected_state],
        show_progress=False
    )

    # Generate on button click or prompt submit.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer_with_lora_wrapper,
        inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, steps, guidance_scale, lora_scale, width, height, gr_flux_loras],
        outputs=[result, seed, reuse_button]
    )

    # Copy the generated image back to the input slot for iterative editing.
    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image]
    )

    # Initialize gallery
    demo.load(
        fn=classify_gallery,
        inputs=[gr_flux_loras],
        outputs=[gallery, gr_flux_loras]
    )

demo.queue(default_concurrency_limit=None)
demo.launch()