Spaces:
Running
on
Zero
Running
on
Zero
rm lazy-loading : preloaded at startup
Browse files
app.py
CHANGED
@@ -1,15 +1,13 @@
|
|
1 |
import spaces
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
-
from PIL import Image
|
5 |
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
|
6 |
import re
|
7 |
|
8 |
-
# Check if CUDA is available
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
11 |
|
12 |
-
# Check if Flash Attention 2 is available
|
13 |
def is_flash_attention_available():
|
14 |
try:
|
15 |
import flash_attn
|
@@ -17,63 +15,33 @@ def is_flash_attention_available():
|
|
17 |
except ImportError:
|
18 |
return False
|
19 |
|
20 |
-
#
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
if is_flash_attention_available():
|
37 |
-
model_kwargs["attn_implementation"] = "flash_attention_2"
|
38 |
-
|
39 |
-
base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
|
40 |
-
base_repo,
|
41 |
-
**model_kwargs
|
42 |
-
)
|
43 |
-
base_processor = AutoProcessor.from_pretrained(base_repo)
|
44 |
-
return base_model, base_processor
|
45 |
-
|
46 |
-
def load_chat_model():
|
47 |
-
global chat_model, chat_processor
|
48 |
-
if chat_model is None:
|
49 |
-
chat_repo = "microsoft/kosmos-2.5-chat"
|
50 |
-
|
51 |
-
# Use Flash Attention 2 if available, otherwise use default attention
|
52 |
-
model_kwargs = {
|
53 |
-
"device_map": "cuda",
|
54 |
-
"dtype": dtype,
|
55 |
-
}
|
56 |
-
if is_flash_attention_available():
|
57 |
-
model_kwargs["attn_implementation"] = "flash_attention_2"
|
58 |
-
|
59 |
-
chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
|
60 |
-
chat_repo,
|
61 |
-
**model_kwargs
|
62 |
-
)
|
63 |
-
chat_processor = AutoProcessor.from_pretrained(chat_repo)
|
64 |
-
return chat_model, chat_processor
|
65 |
|
66 |
def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
|
67 |
y = y.replace(prompt, "")
|
68 |
if "<md>" in prompt:
|
69 |
return y
|
70 |
-
|
71 |
pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
|
72 |
bboxs_raw = re.findall(pattern, y)
|
73 |
lines = re.split(pattern, y)[1:]
|
74 |
bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
|
75 |
bboxs = [[int(j) for j in i] for i in bboxs]
|
76 |
-
|
77 |
info = ""
|
78 |
for i in range(len(lines)):
|
79 |
if i < len(bboxs):
|
@@ -91,65 +59,58 @@ def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
|
|
91 |
def generate_markdown(image):
|
92 |
if image is None:
|
93 |
return "Please upload an image."
|
94 |
-
|
95 |
-
model, processor = load_base_model()
|
96 |
-
|
97 |
prompt = "<md>"
|
98 |
-
inputs =
|
99 |
-
|
100 |
height, width = inputs.pop("height"), inputs.pop("width")
|
101 |
raw_width, raw_height = image.size
|
102 |
scale_height = raw_height / height
|
103 |
scale_width = raw_width / width
|
104 |
-
|
105 |
inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
|
106 |
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
|
107 |
-
|
108 |
with torch.no_grad():
|
109 |
-
generated_ids =
|
110 |
**inputs,
|
111 |
max_new_tokens=1024,
|
112 |
)
|
113 |
-
|
114 |
-
generated_text =
|
115 |
result = generated_text[0].replace(prompt, "").strip()
|
116 |
-
|
117 |
return result
|
118 |
|
119 |
@spaces.GPU(duration=120)
|
120 |
def generate_ocr(image):
|
121 |
if image is None:
|
122 |
return "Please upload an image.", None
|
123 |
-
|
124 |
-
model, processor = load_base_model()
|
125 |
-
|
126 |
prompt = "<ocr>"
|
127 |
-
inputs =
|
128 |
-
|
129 |
height, width = inputs.pop("height"), inputs.pop("width")
|
130 |
raw_width, raw_height = image.size
|
131 |
scale_height = raw_height / height
|
132 |
scale_width = raw_width / width
|
133 |
-
|
134 |
inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
|
135 |
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
|
136 |
-
|
137 |
with torch.no_grad():
|
138 |
-
generated_ids =
|
139 |
**inputs,
|
140 |
max_new_tokens=1024,
|
141 |
)
|
142 |
-
|
143 |
-
generated_text =
|
144 |
-
|
145 |
-
# Post-process OCR output
|
146 |
output_text = post_process_ocr(generated_text[0], scale_height, scale_width)
|
147 |
-
|
148 |
-
# Create visualization
|
149 |
-
from PIL import ImageDraw
|
150 |
vis_image = image.copy()
|
151 |
draw = ImageDraw.Draw(vis_image)
|
152 |
-
|
153 |
lines = output_text.split("\n")
|
154 |
for line in lines:
|
155 |
if not line.strip():
|
@@ -161,7 +122,7 @@ def generate_ocr(image):
|
|
161 |
draw.polygon(coords, outline="red", width=2)
|
162 |
except:
|
163 |
continue
|
164 |
-
|
165 |
return output_text, vis_image
|
166 |
|
167 |
@spaces.GPU(duration=120)
|
@@ -170,54 +131,49 @@ def generate_chat_response(image, question):
|
|
170 |
return "Please upload an image."
|
171 |
if not question.strip():
|
172 |
return "Please ask a question."
|
173 |
-
|
174 |
-
model, processor = load_chat_model()
|
175 |
-
|
176 |
template = "<md>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
|
177 |
prompt = template.format(question)
|
178 |
-
|
179 |
-
inputs =
|
180 |
-
|
181 |
height, width = inputs.pop("height"), inputs.pop("width")
|
182 |
raw_width, raw_height = image.size
|
183 |
scale_height = raw_height / height
|
184 |
scale_width = raw_width / width
|
185 |
-
|
186 |
inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
|
187 |
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
|
188 |
-
|
189 |
with torch.no_grad():
|
190 |
-
generated_ids =
|
191 |
**inputs,
|
192 |
max_new_tokens=1024,
|
193 |
)
|
194 |
-
|
195 |
-
generated_text =
|
196 |
-
|
197 |
-
# Extract only the assistant's response
|
198 |
result = generated_text[0]
|
199 |
if "ASSISTANT:" in result:
|
200 |
result = result.split("ASSISTANT:")[-1].strip()
|
201 |
-
|
202 |
return result
|
203 |
|
204 |
-
# Create Gradio interface
|
205 |
with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
|
206 |
gr.Markdown("""
|
207 |
# KOSMOS-2.5 Document AI Demo
|
208 |
-
|
209 |
Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images!
|
210 |
This demo showcases three capabilities:
|
211 |
-
|
212 |
1. **Markdown Generation**: Convert document images to markdown format
|
213 |
2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
|
214 |
3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat
|
215 |
-
|
216 |
Upload a document image (receipt, form, article, etc.) and try different tasks!
|
217 |
""")
|
218 |
-
|
219 |
with gr.Tabs():
|
220 |
-
# Markdown Generation Tab
|
221 |
with gr.TabItem("π Markdown Generation"):
|
222 |
with gr.Row():
|
223 |
with gr.Column():
|
@@ -229,13 +185,12 @@ with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as d
|
|
229 |
md_button = gr.Button("Generate Markdown", variant="primary")
|
230 |
with gr.Column():
|
231 |
md_output = gr.Textbox(
|
232 |
-
label="Generated Markdown",
|
233 |
-
lines=15,
|
234 |
max_lines=20,
|
235 |
show_copy_button=True
|
236 |
)
|
237 |
-
|
238 |
-
# OCR Tab
|
239 |
with gr.TabItem("π OCR with Bounding Boxes"):
|
240 |
with gr.Row():
|
241 |
with gr.Column():
|
@@ -248,13 +203,12 @@ with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as d
|
|
248 |
with gr.Column():
|
249 |
with gr.Row():
|
250 |
ocr_text = gr.Textbox(
|
251 |
-
label="Extracted Text with Coordinates",
|
252 |
lines=10,
|
253 |
show_copy_button=True
|
254 |
)
|
255 |
ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)")
|
256 |
-
|
257 |
-
# Chat Tab
|
258 |
with gr.TabItem("π¬ Document Q&A (Chat)"):
|
259 |
with gr.Row():
|
260 |
with gr.Column():
|
@@ -275,38 +229,22 @@ with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as d
|
|
275 |
chat_button = gr.Button("Get Answer", variant="primary")
|
276 |
with gr.Column():
|
277 |
chat_output = gr.Textbox(
|
278 |
-
label="Answer",
|
279 |
lines=8,
|
280 |
show_copy_button=True
|
281 |
)
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
outputs=[md_output]
|
288 |
-
)
|
289 |
-
|
290 |
-
ocr_button.click(
|
291 |
-
fn=generate_ocr,
|
292 |
-
inputs=[ocr_image],
|
293 |
-
outputs=[ocr_text, ocr_vis]
|
294 |
-
)
|
295 |
-
|
296 |
-
chat_button.click(
|
297 |
-
fn=generate_chat_response,
|
298 |
-
inputs=[chat_image, chat_question],
|
299 |
-
outputs=[chat_output]
|
300 |
-
)
|
301 |
-
|
302 |
-
# Examples section
|
303 |
gr.Markdown("""
|
304 |
## Example Use Cases:
|
305 |
- **Receipts**: Extract itemized information or ask about totals
|
306 |
- **Forms**: Convert to structured format or answer specific questions
|
307 |
- **Articles**: Get markdown format or ask about content
|
308 |
- **Screenshots**: Extract text or get information about specific elements
|
309 |
-
|
310 |
## Note:
|
311 |
This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
|
312 |
""")
|
|
|
1 |
import spaces
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
+
from PIL import Image, ImageDraw
|
5 |
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
|
6 |
import re
|
7 |
|
|
|
8 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
10 |
|
|
|
11 |
def is_flash_attention_available():
|
12 |
try:
|
13 |
import flash_attn
|
|
|
15 |
except ImportError:
|
16 |
return False
|
17 |
|
18 |
+
# Load models once at startup
|
19 |
+
base_repo = "microsoft/kosmos-2.5"
|
20 |
+
chat_repo = "microsoft/kosmos-2.5-chat"
|
21 |
+
|
22 |
+
model_kwargs = {
|
23 |
+
"device_map": "cuda",
|
24 |
+
"dtype": dtype,
|
25 |
+
}
|
26 |
+
if is_flash_attention_available():
|
27 |
+
model_kwargs["attn_implementation"] = "flash_attention_2"
|
28 |
+
|
29 |
+
base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(base_repo, **model_kwargs)
|
30 |
+
base_processor = AutoProcessor.from_pretrained(base_repo)
|
31 |
+
|
32 |
+
chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(chat_repo, **model_kwargs)
|
33 |
+
chat_processor = AutoProcessor.from_pretrained(chat_repo)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
|
36 |
y = y.replace(prompt, "")
|
37 |
if "<md>" in prompt:
|
38 |
return y
|
|
|
39 |
pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
|
40 |
bboxs_raw = re.findall(pattern, y)
|
41 |
lines = re.split(pattern, y)[1:]
|
42 |
bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
|
43 |
bboxs = [[int(j) for j in i] for i in bboxs]
|
44 |
+
|
45 |
info = ""
|
46 |
for i in range(len(lines)):
|
47 |
if i < len(bboxs):
|
|
|
59 |
def generate_markdown(image):
|
60 |
if image is None:
|
61 |
return "Please upload an image."
|
62 |
+
|
|
|
|
|
63 |
prompt = "<md>"
|
64 |
+
inputs = base_processor(text=prompt, images=image, return_tensors="pt")
|
65 |
+
|
66 |
height, width = inputs.pop("height"), inputs.pop("width")
|
67 |
raw_width, raw_height = image.size
|
68 |
scale_height = raw_height / height
|
69 |
scale_width = raw_width / width
|
70 |
+
|
71 |
inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
|
72 |
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
|
73 |
+
|
74 |
with torch.no_grad():
|
75 |
+
generated_ids = base_model.generate(
|
76 |
**inputs,
|
77 |
max_new_tokens=1024,
|
78 |
)
|
79 |
+
|
80 |
+
generated_text = base_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
81 |
result = generated_text[0].replace(prompt, "").strip()
|
82 |
+
|
83 |
return result
|
84 |
|
85 |
@spaces.GPU(duration=120)
|
86 |
def generate_ocr(image):
|
87 |
if image is None:
|
88 |
return "Please upload an image.", None
|
89 |
+
|
|
|
|
|
90 |
prompt = "<ocr>"
|
91 |
+
inputs = base_processor(text=prompt, images=image, return_tensors="pt")
|
92 |
+
|
93 |
height, width = inputs.pop("height"), inputs.pop("width")
|
94 |
raw_width, raw_height = image.size
|
95 |
scale_height = raw_height / height
|
96 |
scale_width = raw_width / width
|
97 |
+
|
98 |
inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
|
99 |
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
|
100 |
+
|
101 |
with torch.no_grad():
|
102 |
+
generated_ids = base_model.generate(
|
103 |
**inputs,
|
104 |
max_new_tokens=1024,
|
105 |
)
|
106 |
+
|
107 |
+
generated_text = base_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
108 |
+
|
|
|
109 |
output_text = post_process_ocr(generated_text[0], scale_height, scale_width)
|
110 |
+
|
|
|
|
|
111 |
vis_image = image.copy()
|
112 |
draw = ImageDraw.Draw(vis_image)
|
113 |
+
|
114 |
lines = output_text.split("\n")
|
115 |
for line in lines:
|
116 |
if not line.strip():
|
|
|
122 |
draw.polygon(coords, outline="red", width=2)
|
123 |
except:
|
124 |
continue
|
125 |
+
|
126 |
return output_text, vis_image
|
127 |
|
128 |
@spaces.GPU(duration=120)
|
|
|
131 |
return "Please upload an image."
|
132 |
if not question.strip():
|
133 |
return "Please ask a question."
|
134 |
+
|
|
|
|
|
135 |
template = "<md>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
|
136 |
prompt = template.format(question)
|
137 |
+
|
138 |
+
inputs = chat_processor(text=prompt, images=image, return_tensors="pt")
|
139 |
+
|
140 |
height, width = inputs.pop("height"), inputs.pop("width")
|
141 |
raw_width, raw_height = image.size
|
142 |
scale_height = raw_height / height
|
143 |
scale_width = raw_width / width
|
144 |
+
|
145 |
inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
|
146 |
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
|
147 |
+
|
148 |
with torch.no_grad():
|
149 |
+
generated_ids = chat_model.generate(
|
150 |
**inputs,
|
151 |
max_new_tokens=1024,
|
152 |
)
|
153 |
+
|
154 |
+
generated_text = chat_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
155 |
+
|
|
|
156 |
result = generated_text[0]
|
157 |
if "ASSISTANT:" in result:
|
158 |
result = result.split("ASSISTANT:")[-1].strip()
|
159 |
+
|
160 |
return result
|
161 |
|
|
|
162 |
with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
|
163 |
gr.Markdown("""
|
164 |
# KOSMOS-2.5 Document AI Demo
|
165 |
+
|
166 |
Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images!
|
167 |
This demo showcases three capabilities:
|
168 |
+
|
169 |
1. **Markdown Generation**: Convert document images to markdown format
|
170 |
2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
|
171 |
3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat
|
172 |
+
|
173 |
Upload a document image (receipt, form, article, etc.) and try different tasks!
|
174 |
""")
|
175 |
+
|
176 |
with gr.Tabs():
|
|
|
177 |
with gr.TabItem("π Markdown Generation"):
|
178 |
with gr.Row():
|
179 |
with gr.Column():
|
|
|
185 |
md_button = gr.Button("Generate Markdown", variant="primary")
|
186 |
with gr.Column():
|
187 |
md_output = gr.Textbox(
|
188 |
+
label="Generated Markdown",
|
189 |
+
lines=15,
|
190 |
max_lines=20,
|
191 |
show_copy_button=True
|
192 |
)
|
193 |
+
|
|
|
194 |
with gr.TabItem("π OCR with Bounding Boxes"):
|
195 |
with gr.Row():
|
196 |
with gr.Column():
|
|
|
203 |
with gr.Column():
|
204 |
with gr.Row():
|
205 |
ocr_text = gr.Textbox(
|
206 |
+
label="Extracted Text with Coordinates",
|
207 |
lines=10,
|
208 |
show_copy_button=True
|
209 |
)
|
210 |
ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)")
|
211 |
+
|
|
|
212 |
with gr.TabItem("π¬ Document Q&A (Chat)"):
|
213 |
with gr.Row():
|
214 |
with gr.Column():
|
|
|
229 |
chat_button = gr.Button("Get Answer", variant="primary")
|
230 |
with gr.Column():
|
231 |
chat_output = gr.Textbox(
|
232 |
+
label="Answer",
|
233 |
lines=8,
|
234 |
show_copy_button=True
|
235 |
)
|
236 |
+
|
237 |
+
md_button.click(fn=generate_markdown, inputs=[md_image], outputs=[md_output])
|
238 |
+
ocr_button.click(fn=generate_ocr, inputs=[ocr_image], outputs=[ocr_text, ocr_vis])
|
239 |
+
chat_button.click(fn=generate_chat_response, inputs=[chat_image, chat_question], outputs=[chat_output])
|
240 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
gr.Markdown("""
|
242 |
## Example Use Cases:
|
243 |
- **Receipts**: Extract itemized information or ask about totals
|
244 |
- **Forms**: Convert to structured format or answer specific questions
|
245 |
- **Articles**: Get markdown format or ask about content
|
246 |
- **Screenshots**: Extract text or get information about specific elements
|
247 |
+
|
248 |
## Note:
|
249 |
This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
|
250 |
""")
|