prithivMLmods committed on
Commit
f2254d0
·
verified ·
1 Parent(s): dce53e9

rm lazy-loading : preloaded at startup

Browse files
Files changed (1) hide show
  1. app.py +66 -128
app.py CHANGED
@@ -1,15 +1,13 @@
1
  import spaces
2
  import torch
3
  import gradio as gr
4
- from PIL import Image
5
  from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
6
  import re
7
 
8
- # Check if CUDA is available
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
11
 
12
- # Check if Flash Attention 2 is available
13
  def is_flash_attention_available():
14
  try:
15
  import flash_attn
@@ -17,63 +15,33 @@ def is_flash_attention_available():
17
  except ImportError:
18
  return False
19
 
20
- # Initialize models and processors lazily
21
- base_model = None
22
- base_processor = None
23
- chat_model = None
24
- chat_processor = None
25
-
26
- def load_base_model():
27
- global base_model, base_processor
28
- if base_model is None:
29
- base_repo = "microsoft/kosmos-2.5"
30
-
31
- # Use Flash Attention 2 if available, otherwise use default attention
32
- model_kwargs = {
33
- "device_map": "cuda",
34
- "dtype": dtype,
35
- }
36
- if is_flash_attention_available():
37
- model_kwargs["attn_implementation"] = "flash_attention_2"
38
-
39
- base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
40
- base_repo,
41
- **model_kwargs
42
- )
43
- base_processor = AutoProcessor.from_pretrained(base_repo)
44
- return base_model, base_processor
45
-
46
- def load_chat_model():
47
- global chat_model, chat_processor
48
- if chat_model is None:
49
- chat_repo = "microsoft/kosmos-2.5-chat"
50
-
51
- # Use Flash Attention 2 if available, otherwise use default attention
52
- model_kwargs = {
53
- "device_map": "cuda",
54
- "dtype": dtype,
55
- }
56
- if is_flash_attention_available():
57
- model_kwargs["attn_implementation"] = "flash_attention_2"
58
-
59
- chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
60
- chat_repo,
61
- **model_kwargs
62
- )
63
- chat_processor = AutoProcessor.from_pretrained(chat_repo)
64
- return chat_model, chat_processor
65
 
66
  def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
67
  y = y.replace(prompt, "")
68
  if "<md>" in prompt:
69
  return y
70
-
71
  pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
72
  bboxs_raw = re.findall(pattern, y)
73
  lines = re.split(pattern, y)[1:]
74
  bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
75
  bboxs = [[int(j) for j in i] for i in bboxs]
76
-
77
  info = ""
78
  for i in range(len(lines)):
79
  if i < len(bboxs):
@@ -91,65 +59,58 @@ def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
91
  def generate_markdown(image):
92
  if image is None:
93
  return "Please upload an image."
94
-
95
- model, processor = load_base_model()
96
-
97
  prompt = "<md>"
98
- inputs = processor(text=prompt, images=image, return_tensors="pt")
99
-
100
  height, width = inputs.pop("height"), inputs.pop("width")
101
  raw_width, raw_height = image.size
102
  scale_height = raw_height / height
103
  scale_width = raw_width / width
104
-
105
  inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
106
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
107
-
108
  with torch.no_grad():
109
- generated_ids = model.generate(
110
  **inputs,
111
  max_new_tokens=1024,
112
  )
113
-
114
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
115
  result = generated_text[0].replace(prompt, "").strip()
116
-
117
  return result
118
 
119
  @spaces.GPU(duration=120)
120
  def generate_ocr(image):
121
  if image is None:
122
  return "Please upload an image.", None
123
-
124
- model, processor = load_base_model()
125
-
126
  prompt = "<ocr>"
127
- inputs = processor(text=prompt, images=image, return_tensors="pt")
128
-
129
  height, width = inputs.pop("height"), inputs.pop("width")
130
  raw_width, raw_height = image.size
131
  scale_height = raw_height / height
132
  scale_width = raw_width / width
133
-
134
  inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
135
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
136
-
137
  with torch.no_grad():
138
- generated_ids = model.generate(
139
  **inputs,
140
  max_new_tokens=1024,
141
  )
142
-
143
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
144
-
145
- # Post-process OCR output
146
  output_text = post_process_ocr(generated_text[0], scale_height, scale_width)
147
-
148
- # Create visualization
149
- from PIL import ImageDraw
150
  vis_image = image.copy()
151
  draw = ImageDraw.Draw(vis_image)
152
-
153
  lines = output_text.split("\n")
154
  for line in lines:
155
  if not line.strip():
@@ -161,7 +122,7 @@ def generate_ocr(image):
161
  draw.polygon(coords, outline="red", width=2)
162
  except:
163
  continue
164
-
165
  return output_text, vis_image
166
 
167
  @spaces.GPU(duration=120)
@@ -170,54 +131,49 @@ def generate_chat_response(image, question):
170
  return "Please upload an image."
171
  if not question.strip():
172
  return "Please ask a question."
173
-
174
- model, processor = load_chat_model()
175
-
176
  template = "<md>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
177
  prompt = template.format(question)
178
-
179
- inputs = processor(text=prompt, images=image, return_tensors="pt")
180
-
181
  height, width = inputs.pop("height"), inputs.pop("width")
182
  raw_width, raw_height = image.size
183
  scale_height = raw_height / height
184
  scale_width = raw_width / width
185
-
186
  inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
187
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
188
-
189
  with torch.no_grad():
190
- generated_ids = model.generate(
191
  **inputs,
192
  max_new_tokens=1024,
193
  )
194
-
195
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
196
-
197
- # Extract only the assistant's response
198
  result = generated_text[0]
199
  if "ASSISTANT:" in result:
200
  result = result.split("ASSISTANT:")[-1].strip()
201
-
202
  return result
203
 
204
- # Create Gradio interface
205
  with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
206
  gr.Markdown("""
207
  # KOSMOS-2.5 Document AI Demo
208
-
209
  Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images!
210
  This demo showcases three capabilities:
211
-
212
  1. **Markdown Generation**: Convert document images to markdown format
213
  2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
214
  3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat
215
-
216
  Upload a document image (receipt, form, article, etc.) and try different tasks!
217
  """)
218
-
219
  with gr.Tabs():
220
- # Markdown Generation Tab
221
  with gr.TabItem("📝 Markdown Generation"):
222
  with gr.Row():
223
  with gr.Column():
@@ -229,13 +185,12 @@ with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as d
229
  md_button = gr.Button("Generate Markdown", variant="primary")
230
  with gr.Column():
231
  md_output = gr.Textbox(
232
- label="Generated Markdown",
233
- lines=15,
234
  max_lines=20,
235
  show_copy_button=True
236
  )
237
-
238
- # OCR Tab
239
  with gr.TabItem("🔍 OCR with Bounding Boxes"):
240
  with gr.Row():
241
  with gr.Column():
@@ -248,13 +203,12 @@ with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as d
248
  with gr.Column():
249
  with gr.Row():
250
  ocr_text = gr.Textbox(
251
- label="Extracted Text with Coordinates",
252
  lines=10,
253
  show_copy_button=True
254
  )
255
  ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)")
256
-
257
- # Chat Tab
258
  with gr.TabItem("💬 Document Q&A (Chat)"):
259
  with gr.Row():
260
  with gr.Column():
@@ -275,38 +229,22 @@ with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as d
275
  chat_button = gr.Button("Get Answer", variant="primary")
276
  with gr.Column():
277
  chat_output = gr.Textbox(
278
- label="Answer",
279
  lines=8,
280
  show_copy_button=True
281
  )
282
-
283
- # Event handlers
284
- md_button.click(
285
- fn=generate_markdown,
286
- inputs=[md_image],
287
- outputs=[md_output]
288
- )
289
-
290
- ocr_button.click(
291
- fn=generate_ocr,
292
- inputs=[ocr_image],
293
- outputs=[ocr_text, ocr_vis]
294
- )
295
-
296
- chat_button.click(
297
- fn=generate_chat_response,
298
- inputs=[chat_image, chat_question],
299
- outputs=[chat_output]
300
- )
301
-
302
- # Examples section
303
  gr.Markdown("""
304
  ## Example Use Cases:
305
  - **Receipts**: Extract itemized information or ask about totals
306
  - **Forms**: Convert to structured format or answer specific questions
307
  - **Articles**: Get markdown format or ask about content
308
  - **Screenshots**: Extract text or get information about specific elements
309
-
310
  ## Note:
311
  This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
312
  """)
 
1
  import spaces
2
  import torch
3
  import gradio as gr
4
+ from PIL import Image, ImageDraw
5
  from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
6
  import re
7
 
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
10
 
 
11
  def is_flash_attention_available():
12
  try:
13
  import flash_attn
 
15
  except ImportError:
16
  return False
17
 
18
+ # Load models once at startup
19
+ base_repo = "microsoft/kosmos-2.5"
20
+ chat_repo = "microsoft/kosmos-2.5-chat"
21
+
22
+ model_kwargs = {
23
+ "device_map": "cuda",
24
+ "dtype": dtype,
25
+ }
26
+ if is_flash_attention_available():
27
+ model_kwargs["attn_implementation"] = "flash_attention_2"
28
+
29
+ base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(base_repo, **model_kwargs)
30
+ base_processor = AutoProcessor.from_pretrained(base_repo)
31
+
32
+ chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(chat_repo, **model_kwargs)
33
+ chat_processor = AutoProcessor.from_pretrained(chat_repo)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
36
  y = y.replace(prompt, "")
37
  if "<md>" in prompt:
38
  return y
 
39
  pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
40
  bboxs_raw = re.findall(pattern, y)
41
  lines = re.split(pattern, y)[1:]
42
  bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
43
  bboxs = [[int(j) for j in i] for i in bboxs]
44
+
45
  info = ""
46
  for i in range(len(lines)):
47
  if i < len(bboxs):
 
59
  def generate_markdown(image):
60
  if image is None:
61
  return "Please upload an image."
62
+
 
 
63
  prompt = "<md>"
64
+ inputs = base_processor(text=prompt, images=image, return_tensors="pt")
65
+
66
  height, width = inputs.pop("height"), inputs.pop("width")
67
  raw_width, raw_height = image.size
68
  scale_height = raw_height / height
69
  scale_width = raw_width / width
70
+
71
  inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
72
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
73
+
74
  with torch.no_grad():
75
+ generated_ids = base_model.generate(
76
  **inputs,
77
  max_new_tokens=1024,
78
  )
79
+
80
+ generated_text = base_processor.batch_decode(generated_ids, skip_special_tokens=True)
81
  result = generated_text[0].replace(prompt, "").strip()
82
+
83
  return result
84
 
85
  @spaces.GPU(duration=120)
86
  def generate_ocr(image):
87
  if image is None:
88
  return "Please upload an image.", None
89
+
 
 
90
  prompt = "<ocr>"
91
+ inputs = base_processor(text=prompt, images=image, return_tensors="pt")
92
+
93
  height, width = inputs.pop("height"), inputs.pop("width")
94
  raw_width, raw_height = image.size
95
  scale_height = raw_height / height
96
  scale_width = raw_width / width
97
+
98
  inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
99
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
100
+
101
  with torch.no_grad():
102
+ generated_ids = base_model.generate(
103
  **inputs,
104
  max_new_tokens=1024,
105
  )
106
+
107
+ generated_text = base_processor.batch_decode(generated_ids, skip_special_tokens=True)
108
+
 
109
  output_text = post_process_ocr(generated_text[0], scale_height, scale_width)
110
+
 
 
111
  vis_image = image.copy()
112
  draw = ImageDraw.Draw(vis_image)
113
+
114
  lines = output_text.split("\n")
115
  for line in lines:
116
  if not line.strip():
 
122
  draw.polygon(coords, outline="red", width=2)
123
  except:
124
  continue
125
+
126
  return output_text, vis_image
127
 
128
  @spaces.GPU(duration=120)
 
131
  return "Please upload an image."
132
  if not question.strip():
133
  return "Please ask a question."
134
+
 
 
135
  template = "<md>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
136
  prompt = template.format(question)
137
+
138
+ inputs = chat_processor(text=prompt, images=image, return_tensors="pt")
139
+
140
  height, width = inputs.pop("height"), inputs.pop("width")
141
  raw_width, raw_height = image.size
142
  scale_height = raw_height / height
143
  scale_width = raw_width / width
144
+
145
  inputs = {k: v.to("cuda") if v is not None else None for k, v in inputs.items()}
146
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
147
+
148
  with torch.no_grad():
149
+ generated_ids = chat_model.generate(
150
  **inputs,
151
  max_new_tokens=1024,
152
  )
153
+
154
+ generated_text = chat_processor.batch_decode(generated_ids, skip_special_tokens=True)
155
+
 
156
  result = generated_text[0]
157
  if "ASSISTANT:" in result:
158
  result = result.split("ASSISTANT:")[-1].strip()
159
+
160
  return result
161
 
 
162
  with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
163
  gr.Markdown("""
164
  # KOSMOS-2.5 Document AI Demo
165
+
166
  Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images!
167
  This demo showcases three capabilities:
168
+
169
  1. **Markdown Generation**: Convert document images to markdown format
170
  2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
171
  3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat
172
+
173
  Upload a document image (receipt, form, article, etc.) and try different tasks!
174
  """)
175
+
176
  with gr.Tabs():
 
177
  with gr.TabItem("📝 Markdown Generation"):
178
  with gr.Row():
179
  with gr.Column():
 
185
  md_button = gr.Button("Generate Markdown", variant="primary")
186
  with gr.Column():
187
  md_output = gr.Textbox(
188
+ label="Generated Markdown",
189
+ lines=15,
190
  max_lines=20,
191
  show_copy_button=True
192
  )
193
+
 
194
  with gr.TabItem("🔍 OCR with Bounding Boxes"):
195
  with gr.Row():
196
  with gr.Column():
 
203
  with gr.Column():
204
  with gr.Row():
205
  ocr_text = gr.Textbox(
206
+ label="Extracted Text with Coordinates",
207
  lines=10,
208
  show_copy_button=True
209
  )
210
  ocr_vis = gr.Image(label="Visualization (Red boxes show detected text)")
211
+
 
212
  with gr.TabItem("💬 Document Q&A (Chat)"):
213
  with gr.Row():
214
  with gr.Column():
 
229
  chat_button = gr.Button("Get Answer", variant="primary")
230
  with gr.Column():
231
  chat_output = gr.Textbox(
232
+ label="Answer",
233
  lines=8,
234
  show_copy_button=True
235
  )
236
+
237
+ md_button.click(fn=generate_markdown, inputs=[md_image], outputs=[md_output])
238
+ ocr_button.click(fn=generate_ocr, inputs=[ocr_image], outputs=[ocr_text, ocr_vis])
239
+ chat_button.click(fn=generate_chat_response, inputs=[chat_image, chat_question], outputs=[chat_output])
240
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  gr.Markdown("""
242
  ## Example Use Cases:
243
  - **Receipts**: Extract itemized information or ask about totals
244
  - **Forms**: Convert to structured format or answer specific questions
245
  - **Articles**: Get markdown format or ask about content
246
  - **Screenshots**: Extract text or get information about specific elements
247
+
248
  ## Note:
249
  This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
250
  """)