awacke1 committed on
Commit
3966f89
·
verified ·
1 Parent(s): ba6ccdb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +244 -0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import glob
4
+ import time
5
+ import pandas as pd
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from diffusers import StableDiffusionPipeline
9
+ import fitz
10
+ import requests
11
+ from PIL import Image
12
+ import logging
13
+ import asyncio
14
+ import aiofiles
15
+ from io import BytesIO
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+ import gradio as gr
19
+
20
# Configure the root logger, then keep an in-memory copy of every record
# that passes through this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
log_records = []


class LogCaptureHandler(logging.Handler):
    """Logging handler that appends each received record to ``log_records``."""

    def emit(self, record):
        # The raw LogRecord is stored, not its formatted text.
        log_records.append(record)


logger.addHandler(LogCaptureHandler())
29
+
30
@dataclass
class ModelConfig:
    """Configuration record for a model.

    ``domain`` is optional and ``model_type`` defaults to ``"causal_lm"``.
    """

    name: str
    base_model: str
    size: str
    domain: Optional[str] = None
    model_type: str = "causal_lm"

    @property
    def model_path(self):
        """Return the relative path ``models/<name>`` for this config."""
        return "models/{}".format(self.name)
40
+
41
@dataclass
class DiffusionConfig:
    """Configuration record for a diffusion model; ``domain`` is optional."""

    name: str
    base_model: str
    size: str
    domain: Optional[str] = None

    @property
    def model_path(self):
        """Return the relative path ``diffusion_models/<name>`` for this config."""
        return "diffusion_models/{}".format(self.name)
50
+
51
class ModelBuilder:
    """Holds a causal language model, its tokenizer and an optional config."""

    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None

    def load_model(self, model_path: str, config: Optional["ModelConfig"] = None):
        """Load the model and tokenizer from ``model_path``.

        The tokenizer's pad token falls back to its EOS token when unset.
        ``config`` replaces the stored config only when it is truthy. The
        model is moved to CUDA when available, otherwise to the CPU.

        Returns:
            This builder, so calls can be chained.
        """
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if config:
            self.config = config
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device)
        return self

    def save_model(self, path: str):
        """Save the model and tokenizer into the directory ``path``.

        The target directory itself is created (with parents). Creating only
        ``os.path.dirname(path)`` failed for a bare directory name, because
        the dirname is then the empty string.
        """
        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
69
+
70
class DiffusionBuilder:
    """Holds a Stable Diffusion pipeline and an optional config."""

    def __init__(self):
        self.config = None
        self.pipeline = None

    def load_model(self, model_path: str, config: Optional["DiffusionConfig"] = None):
        """Load a float32 Stable Diffusion pipeline from ``model_path`` onto the CPU.

        ``config`` replaces the stored config only when it is truthy.

        Returns:
            This builder, so calls can be chained.
        """
        loaded = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32)
        self.pipeline = loaded.to("cpu")
        if config:
            self.config = config
        return self

    def save_model(self, path: str):
        """Save the pipeline into the directory ``path``.

        The target directory itself is created (with parents). Creating only
        ``os.path.dirname(path)`` failed for a bare directory name, because
        the dirname is then the empty string.
        """
        os.makedirs(path, exist_ok=True)
        self.pipeline.save_pretrained(path)

    def generate(self, prompt: str):
        """Run the pipeline for 20 inference steps and return the first image."""
        return self.pipeline(prompt, num_inference_steps=20).images[0]
84
+
85
def generate_filename(sequence, ext="png"):
    """Build ``<sequence>_<DDMMYYYYHHMMSS>.<ext>`` from the current local time."""
    stamp = time.strftime("%d%m%Y%H%M%S")
    return "{}_{}.{}".format(sequence, stamp, ext)
88
+
89
def get_gallery_files(file_types):
    """Return the sorted, de-duplicated files in the working directory
    whose names match ``*.<ext>`` for any extension in ``file_types``."""
    matches = set()
    for extension in file_types:
        matches.update(glob.glob(f"*.{extension}"))
    return sorted(matches)
91
+
92
async def process_image_gen(prompt, output_file, builder):
    """Generate an image for ``prompt``, save it to ``output_file`` and return it.

    The pipeline held by ``builder`` is used when ``builder`` is a
    ``DiffusionBuilder`` with a loaded pipeline; otherwise the
    ``OFA-Sys/small-stable-diffusion-v0`` pipeline is loaded on the CPU.
    """
    pipeline = None
    if builder and isinstance(builder, DiffusionBuilder):
        pipeline = builder.pipeline
    if not pipeline:
        # No usable builder pipeline: fall back to the small default model.
        pipeline = StableDiffusionPipeline.from_pretrained(
            "OFA-Sys/small-stable-diffusion-v0", torch_dtype=torch.float32
        ).to("cpu")
    result = pipeline(prompt, num_inference_steps=20)
    gen_image = result.images[0]
    gen_image.save(output_file)
    return gen_image
100
+
101
+ # Smart Uploader Functions
102
def upload_files(files, links_title, links_url, history, selected_files):
    """Store uploaded files in the working directory and register links.

    Args:
        files: Iterable of uploads, or ``None``. Each item may be raw
            ``bytes`` (what ``gr.File(type="binary")`` delivers), a
            filesystem path, or an object exposing ``name`` and ``read()``.
        links_title: Newline-separated link titles.
        links_url: Newline-separated link URLs, paired with titles by position.
        history: List of log strings; appended to in place.
        selected_files: Mapping of file path / link entry to a selection
            flag; new entries are added unselected.

    Returns:
        ``(uploaded, history, selected_files)`` where ``uploaded`` groups the
        new paths and link entries by category.
    """
    uploaded = {"images": [], "videos": [], "documents": [], "datasets": [], "links": []}
    category_by_ext = {
        "jpg": "images", "png": "images",
        "mp4": "videos",
        "md": "documents", "pdf": "documents", "docx": "documents",
        "csv": "datasets", "xlsx": "datasets",
    }

    for index, file in enumerate(files or []):
        if isinstance(file, (bytes, bytearray)):
            # Binary uploads carry no filename, so synthesize one.
            name, data = f"upload_{index}.bin", bytes(file)
        elif isinstance(file, (str, os.PathLike)):
            name = os.path.basename(os.fspath(file))
            with open(file, "rb") as src:
                data = src.read()
        else:
            # ``name`` may be a full temp path; keep only its last component
            # so the output stays a plain file in the working directory.
            name = os.path.basename(str(getattr(file, "name", ""))) or f"upload_{index}.bin"
            data = file.read()

        ext = name.split('.')[-1].lower()
        output_path = f"uploaded_{int(time.time())}_{name}"
        with open(output_path, "wb") as f:
            f.write(data)

        category = category_by_ext.get(ext)
        if category is not None:
            uploaded[category].append(output_path)
        history.append(f"Uploaded: {output_path}")
        selected_files[output_path] = False  # Default unchecked

    if links_title and links_url:
        for title, url in zip(links_title.split('\n'), links_url.split('\n')):
            # Strip stray whitespace such as "\r" from pasted text.
            title, url = title.strip(), url.strip()
            if title and url:
                link_entry = f"[{title}]({url})"
                uploaded["links"].append(link_entry)
                history.append(f"Added Link: {link_entry}")
                selected_files[link_entry] = False
    return uploaded, history, selected_files
129
+
130
def _load_image(path):
    """Read an image fully into memory so the file handle can be closed."""
    with Image.open(path) as img:
        return img.copy()


def _pdf_thumbnail(path):
    """Render the first page of a PDF at half scale as an RGB PIL image."""
    doc = fitz.open(path)
    try:
        pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.5, 0.5))
        # Pixmap.size is a byte count; PIL needs the (width, height) pair.
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
    finally:
        doc.close()


def update_galleries(history, selected_files):
    """Rebuild the gallery contents from files in the working directory.

    At most four entries are shown per gallery. Link entries are the
    ``[title](url)`` keys found in ``selected_files``.

    Args:
        history: List of log strings; a summary line is appended in place.
        selected_files: Mapping of file path / link entry to a selection flag.

    Returns:
        A 7-tuple ``(images, videos, documents, datasets, links, history,
        selected_files)``, one value per output component wired to this
        handler in the UI. Each gallery value is a list of
        ``(media, caption)`` pairs.
    """
    galleries = {
        "images": get_gallery_files(["jpg", "png"]),
        "videos": get_gallery_files(["mp4"]),
        "documents": get_gallery_files(["md", "pdf", "docx"]),
        "datasets": get_gallery_files(["csv", "xlsx"]),
        "links": [f for f in selected_files.keys() if f.startswith('[') and '](' in f and f.endswith(')')]
    }

    images = [(_load_image(f), os.path.basename(f)) for f in galleries["images"][:4]]
    videos = [(f, os.path.basename(f)) for f in galleries["videos"][:4]]  # Video preview as file path
    documents = [
        (_pdf_thumbnail(f) if f.endswith('.pdf') else f, os.path.basename(f))
        for f in galleries["documents"][:4]
    ]
    datasets = [(f, os.path.basename(f)) for f in galleries["datasets"][:4]]  # Text preview
    links = [(f, f.split(']')[0][1:]) for f in galleries["links"][:4]]

    history.append(f"Updated galleries: {sum(len(g) for g in galleries.values())} files")
    return images, videos, documents, datasets, links, history, selected_files
147
+
148
def toggle_selection(file_list, selected_files):
    """Flip the selection flag of every entry named in ``file_list``.

    Args:
        file_list: Keys to toggle. ``None`` (nothing chosen in the UI) is
            treated as empty, and a single string is treated as one key
            rather than iterated character by character.
        selected_files: Mapping of key to selection flag; updated in place.
            Unknown keys are treated as unselected and become selected.

    Returns:
        The same ``selected_files`` mapping.
    """
    if file_list is None:
        return selected_files
    if isinstance(file_list, str):
        file_list = [file_list]
    for file in file_list:
        selected_files[file] = not selected_files.get(file, False)
    return selected_files
152
+
153
def image_gen(prompt, builder, history, selected_files):
    """Generate an image from ``prompt`` when at least one image is selected.

    Returns:
        ``(status, image, history, selected_files)``. With no selected
        ``.jpg``/``.png`` entry the status is ``"No images selected"`` and the
        image is ``None``. Otherwise the new file is logged in ``history``
        and marked as selected.
    """
    chosen = []
    for path, is_selected in selected_files.items():
        if is_selected and path.endswith(('.jpg', '.png')):
            chosen.append(path)

    if chosen:
        output_file = generate_filename("gen_output", "png")
        gen_image = asyncio.run(process_image_gen(prompt, output_file, builder))
        history.append(f"Image Gen: {prompt} -> {output_file}")
        selected_files[output_file] = True
        return f"Image saved to {output_file}", gen_image, history, selected_files

    return "No images selected", None, history, selected_files
162
+
163
+ # Gradio UI
164
def _render_history(entries):
    """Return the five most recent history entries, one per line."""
    return "\n".join(entries[-5:])


def _selector_choices(selection):
    """Return a Dropdown update listing every tracked file/link key."""
    return gr.update(choices=list(selection))


# Emoji in the labels below are restored from mis-decoded UTF-8 text.
with gr.Blocks(title="AI Vision & SFT Titans 🚀") as demo:
    gr.Markdown("# AI Vision & SFT Titans 🚀")
    history = gr.State(value=[])
    builder = gr.State(value=None)
    selected_files = gr.State(value={})

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 📁 File Tree")
            with gr.Accordion("🌳 Uploads", open=True):
                with gr.Row():
                    gr.Markdown("### 🖼️ Images (jpg/png)")
                    img_gallery = gr.Gallery(label="Images", columns=4, height="auto")
                with gr.Row():
                    gr.Markdown("### 🎥 Videos (mp4)")
                    vid_gallery = gr.Gallery(label="Videos", columns=4, height="auto")
                with gr.Row():
                    gr.Markdown("### 📜 Docs (md/pdf/docx)")
                    doc_gallery = gr.Gallery(label="Documents", columns=4, height="auto")
                with gr.Row():
                    gr.Markdown("### 📊 Data (csv/xlsx)")
                    data_gallery = gr.Gallery(label="Datasets", columns=4, height="auto")
                with gr.Row():
                    gr.Markdown("### 🔗 Links")
                    link_gallery = gr.Gallery(label="Links", columns=4, height="auto")
            gr.Markdown("## 📜 History")
            history_output = gr.Textbox(label="Log", lines=5, interactive=False)

        with gr.Column(scale=3):
            with gr.Row():
                gr.Markdown("## 🛠️ Toolbar")
                upload_btn = gr.Button("📤 Upload")
                select_btn = gr.Button("✅ Select")
                gen_btn = gr.Button("🎨 Generate")

            with gr.Tabs():
                with gr.TabItem("📤 Smart Upload"):
                    file_upload = gr.File(label="Upload Files", file_count="multiple", type="binary")
                    links_title = gr.Textbox(label="Link Titles (one per line)", lines=3)
                    links_url = gr.Textbox(label="Link URLs (one per line)", lines=3)
                    upload_status = gr.Textbox(label="Status")

                with gr.TabItem("🔍 Operations"):
                    prompt = gr.Textbox(label="Image Gen Prompt", value="Generate a neon version")
                    op_status = gr.Textbox(label="Status")
                    op_output = gr.Image(label="Output")

    # A named selector whose choices are refreshed after every event; choices
    # computed once at build time would stay empty forever.
    file_selector = gr.Dropdown(choices=[], multiselect=True, label="Select Files")

    gallery_targets = [
        img_gallery, vid_gallery, doc_gallery, data_gallery, link_gallery,
        history, selected_files,
    ]

    def _refresh_after(event):
        """Chain gallery, history-log and selector refreshes onto ``event``."""
        return (
            event.then(update_galleries, inputs=[history, selected_files], outputs=gallery_targets)
            # The log textbox must be re-rendered after each action; filling it
            # only on page load left it permanently empty.
            .then(_render_history, inputs=[history], outputs=[history_output])
            .then(_selector_choices, inputs=[selected_files], outputs=[file_selector])
        )

    _refresh_after(
        upload_btn.click(
            upload_files,
            inputs=[file_upload, links_title, links_url, history, selected_files],
            outputs=[upload_status, history, selected_files],
        )
    )

    _refresh_after(
        select_btn.click(
            toggle_selection,
            inputs=[file_selector, selected_files],
            outputs=[selected_files],
        )
    )

    _refresh_after(
        gen_btn.click(
            image_gen,
            inputs=[prompt, builder, history, selected_files],
            outputs=[op_status, op_output, history, selected_files],
        )
    )

    # Update history output
    demo.load(_render_history, inputs=[history], outputs=[history_output])

demo.launch()