xiaoyao9184 commited on
Commit
5b927c3
·
verified ·
1 Parent(s): 9a4bd00

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (1) hide show
  1. gradio_app.py +107 -67
gradio_app.py CHANGED
@@ -16,14 +16,13 @@ from typing import List
16
  import pypdfium2
17
  import gradio as gr
18
 
 
19
  from surya.models import load_predictors
20
 
21
  from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image
22
 
23
  from surya.debug.text import draw_text_on_image
24
-
25
- from PIL import Image
26
- from surya.recognition.languages import CODE_TO_LANGUAGE, replace_lang_with_code
27
  from surya.table_rec import TableResult
28
  from surya.detection import TextDetectionResult
29
  from surya.recognition import OCRResult
@@ -33,15 +32,18 @@ from surya.common.util import rescale_bbox, expand_bbox
33
 
34
 
35
  # just copy from streamlit_app.py
36
- def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
37
  from pdftext.extraction import plain_text_output
 
38
  with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
39
  f.write(pdf_file.getvalue())
40
  f.seek(0)
41
 
42
  # Sample the text from the middle of the PDF
43
  page_middle = page_count // 2
44
- page_range = range(max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count))
 
 
45
  text = plain_text_output(f.name, page_range=page_range)
46
 
47
  sample_gap = len(text) // max_samples
@@ -54,24 +56,14 @@ def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pag
54
  # Split the text into samples for the model
55
  samples = []
56
  for i in range(0, len(text), sample_gap):
57
- samples.append(text[i:i + sample_len])
58
 
59
  results = predictors["ocr_error"](samples)
60
  label = "This PDF has good text."
61
- if results.labels.count("bad") / len(results.labels) > .2:
62
  label = "This PDF may have garbled or bad OCR text."
63
  return label, results.labels
64
 
65
- # just copy from streamlit_app.py
66
- def inline_detection(img) -> (Image.Image, TextDetectionResult):
67
- text_pred = predictors["detection"]([img])[0]
68
- text_boxes = [p.bbox for p in text_pred.bboxes]
69
-
70
- inline_pred = predictors["inline_detection"]([img], [text_boxes], include_maps=True)[0]
71
- inline_polygons = [p.polygon for p in inline_pred.bboxes]
72
- det_img = draw_polys_on_image(inline_polygons, img.copy(), color='blue')
73
- return det_img, text_pred, inline_pred
74
-
75
  # just copy from streamlit_app.py
76
  def text_detection(img) -> (Image.Image, TextDetectionResult):
77
  text_pred = predictors["detection"]([img])[0]
@@ -83,27 +75,35 @@ def text_detection(img) -> (Image.Image, TextDetectionResult):
83
  def layout_detection(img) -> (Image.Image, LayoutResult):
84
  pred = predictors["layout"]([img])[0]
85
  polygons = [p.polygon for p in pred.bboxes]
86
- labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
87
- layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
 
 
 
 
88
  return layout_img, pred
89
 
90
  # just copy from streamlit_app.py
91
- def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
 
 
92
  if skip_table_detection:
93
  layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
94
  table_imgs = [highres_img]
95
  else:
96
  _, layout_pred = layout_detection(img)
97
- layout_tables_lowres = [l.bbox for l in layout_pred.bboxes if l.label in ["Table", "TableOfContents"]]
 
 
 
 
98
  table_imgs = []
99
  layout_tables = []
100
  for tb in layout_tables_lowres:
101
  highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
102
  # Slightly expand the box
103
  highres_bbox = expand_bbox(highres_bbox)
104
- table_imgs.append(
105
- highres_img.crop(highres_bbox)
106
- )
107
  layout_tables.append(highres_bbox)
108
 
109
  table_preds = predictors["table_rec"](table_imgs)
@@ -115,29 +115,72 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Im
115
  colors = []
116
 
117
  for item in results.cells:
118
- adjusted_bboxes.append([
119
- (item.bbox[0] + table_bbox[0]),
120
- (item.bbox[1] + table_bbox[1]),
121
- (item.bbox[2] + table_bbox[0]),
122
- (item.bbox[3] + table_bbox[1])
123
- ])
 
 
124
  labels.append(item.label)
125
  if "Row" in item.label:
126
  colors.append("blue")
127
  else:
128
  colors.append("red")
129
- table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors)
 
 
 
 
 
 
130
  return table_img, table_preds
131
 
132
  # just copy from streamlit_app.py
133
- def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
134
- replace_lang_with_code(langs)
135
- img_pred = predictors["recognition"]([img], [langs], predictors["detection"], highres_images=[highres_img])[0]
 
 
 
 
 
 
 
 
 
136
 
137
- bboxes = [l.bbox for l in img_pred.text_lines]
138
- text = [l.text for l in img_pred.text_lines]
139
- rec_img = draw_text_on_image(bboxes, text, img.size, langs)
140
- return rec_img, img_pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  def open_pdf(pdf_file):
143
  return pypdfium2.PdfDocument(pdf_file)
@@ -188,13 +231,13 @@ with gr.Blocks(title="Surya") as demo:
188
  in_img = gr.Image(label="Select page of Image", type="pil", sources=None)
189
 
190
  text_det_btn = gr.Button("Run Text Detection")
191
- inline_det_btn = gr.Button("Run Inline Math Detection")
192
  layout_det_btn = gr.Button("Run Layout Analysis")
193
 
194
- lang_dd = gr.Dropdown(label="Languages", choices=sorted(list(CODE_TO_LANGUAGE.values())), multiselect=True, max_choices=4, info="Select the languages in the image (if known) to improve OCR accuracy. Optional.")
 
 
195
  text_rec_btn = gr.Button("Run OCR")
196
 
197
- use_pdf_boxes_ckb = gr.Checkbox(label="Use PDF table boxes", value=True, info="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
198
  skip_table_detection_ckb = gr.Checkbox(label="Skip table detection", value=False, info="Table recognition only: Skip table detection and treat the whole image/page as a table.")
199
  table_rec_btn = gr.Button("Run Table Rec")
200
 
@@ -202,6 +245,7 @@ with gr.Blocks(title="Surya") as demo:
202
  with gr.Column():
203
  result_img = gr.Image(label="Result image")
204
  result_json = gr.JSON(label="Result json")
 
205
 
206
  def show_image(file, num=1):
207
  if file.endswith('.pdf'):
@@ -236,19 +280,6 @@ with gr.Blocks(title="Surya") as demo:
236
  inputs=[in_img],
237
  outputs=[result_img, result_json]
238
  )
239
- def inline_det_img(pil_image):
240
- det_img, text_pred, inline_pred = inline_detection(pil_image)
241
- json = {
242
- "text": text_pred.model_dump(exclude=["heatmap", "affinity_map"]),
243
- "inline": inline_pred.model_dump(exclude=["heatmap", "affinity_map"])
244
- }
245
- return det_img, json
246
- inline_det_btn.click(
247
- fn=inline_det_img,
248
- inputs=[in_img],
249
- outputs=[result_img, result_json]
250
- )
251
-
252
 
253
  # Run layout
254
  def layout_det_img(pil_image):
@@ -259,21 +290,29 @@ with gr.Blocks(title="Surya") as demo:
259
  inputs=[in_img],
260
  outputs=[result_img, result_json]
261
  )
 
262
  # Run OCR
263
- def text_rec_img(pil_image, in_file, page_number, languages):
264
  if in_file.endswith('.pdf'):
265
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
266
  else:
267
  pil_image_highres = pil_image
268
- rec_img, pred = ocr(pil_image, pil_image_highres, languages)
269
- return rec_img, pred.model_dump()
 
 
 
 
 
 
270
  text_rec_btn.click(
271
  fn=text_rec_img,
272
- inputs=[in_img, in_file, in_num, lang_dd],
273
- outputs=[result_img, result_json]
274
  )
 
275
  # Run Table Recognition
276
- def table_rec_img(pil_image, in_file, page_number, use_pdf_boxes, skip_table_detection):
277
  if in_file.endswith('.pdf'):
278
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
279
  else:
@@ -282,20 +321,21 @@ with gr.Blocks(title="Surya") as demo:
282
  return table_img, [p.model_dump() for p in pred]
283
  table_rec_btn.click(
284
  fn=table_rec_img,
285
- inputs=[in_img, in_file, in_num, use_pdf_boxes_ckb, skip_table_detection_ckb],
286
  outputs=[result_img, result_json]
287
  )
 
288
  # Run bad PDF text detection
289
- def ocr_errors_pdf(file, page_count, sample_len=512, max_samples=10, max_pages=15):
290
- if file.endswith('.pdf'):
291
- count = page_counter(file)
292
- else:
293
  raise gr.Error("This feature only works with PDFs.", duration=5)
294
- label, results = run_ocr_errors(io.BytesIO(open(file.name, "rb").read()), count)
 
 
295
  return gr.update(label="Result json:" + label, value=results)
296
  ocr_errors_btn.click(
297
  fn=ocr_errors_pdf,
298
- inputs=[in_file, in_num, use_pdf_boxes_ckb, skip_table_detection_ckb],
299
  outputs=[result_json]
300
  )
301
 
 
16
  import pypdfium2
17
  import gradio as gr
18
 
19
+ from surya.common.surya.schema import TaskNames
20
  from surya.models import load_predictors
21
 
22
  from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image
23
 
24
  from surya.debug.text import draw_text_on_image
25
+ from PIL import Image, ImageDraw
 
 
26
  from surya.table_rec import TableResult
27
  from surya.detection import TextDetectionResult
28
  from surya.recognition import OCRResult
 
32
 
33
 
34
  # just copy from streamlit_app.py
35
+ def ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
36
  from pdftext.extraction import plain_text_output
37
+
38
  with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
39
  f.write(pdf_file.getvalue())
40
  f.seek(0)
41
 
42
  # Sample the text from the middle of the PDF
43
  page_middle = page_count // 2
44
+ page_range = range(
45
+ max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count)
46
+ )
47
  text = plain_text_output(f.name, page_range=page_range)
48
 
49
  sample_gap = len(text) // max_samples
 
56
  # Split the text into samples for the model
57
  samples = []
58
  for i in range(0, len(text), sample_gap):
59
+ samples.append(text[i : i + sample_len])
60
 
61
  results = predictors["ocr_error"](samples)
62
  label = "This PDF has good text."
63
+ if results.labels.count("bad") / len(results.labels) > 0.2:
64
  label = "This PDF may have garbled or bad OCR text."
65
  return label, results.labels
66
 
 
 
 
 
 
 
 
 
 
 
67
  # just copy from streamlit_app.py
68
  def text_detection(img) -> (Image.Image, TextDetectionResult):
69
  text_pred = predictors["detection"]([img])[0]
 
75
  def layout_detection(img) -> (Image.Image, LayoutResult):
76
  pred = predictors["layout"]([img])[0]
77
  polygons = [p.polygon for p in pred.bboxes]
78
+ labels = [
79
+ f"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}" for p in pred.bboxes
80
+ ]
81
+ layout_img = draw_polys_on_image(
82
+ polygons, img.copy(), labels=labels, label_font_size=18
83
+ )
84
  return layout_img, pred
85
 
86
  # just copy from streamlit_app.py
87
+ def table_recognition(
88
+ img, highres_img, skip_table_detection: bool
89
+ ) -> (Image.Image, List[TableResult]):
90
  if skip_table_detection:
91
  layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
92
  table_imgs = [highres_img]
93
  else:
94
  _, layout_pred = layout_detection(img)
95
+ layout_tables_lowres = [
96
+ line.bbox
97
+ for line in layout_pred.bboxes
98
+ if line.label in ["Table", "TableOfContents"]
99
+ ]
100
  table_imgs = []
101
  layout_tables = []
102
  for tb in layout_tables_lowres:
103
  highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
104
  # Slightly expand the box
105
  highres_bbox = expand_bbox(highres_bbox)
106
+ table_imgs.append(highres_img.crop(highres_bbox))
 
 
107
  layout_tables.append(highres_bbox)
108
 
109
  table_preds = predictors["table_rec"](table_imgs)
 
115
  colors = []
116
 
117
  for item in results.cells:
118
+ adjusted_bboxes.append(
119
+ [
120
+ (item.bbox[0] + table_bbox[0]),
121
+ (item.bbox[1] + table_bbox[1]),
122
+ (item.bbox[2] + table_bbox[0]),
123
+ (item.bbox[3] + table_bbox[1]),
124
+ ]
125
+ )
126
  labels.append(item.label)
127
  if "Row" in item.label:
128
  colors.append("blue")
129
  else:
130
  colors.append("red")
131
+ table_img = draw_bboxes_on_image(
132
+ adjusted_bboxes,
133
+ highres_img,
134
+ labels=labels,
135
+ label_font_size=18,
136
+ color=colors,
137
+ )
138
  return table_img, table_preds
139
 
140
  # just copy from streamlit_app.py
141
+ def ocr(
142
+ img: Image.Image,
143
+ highres_img: Image.Image,
144
+ skip_text_detection: bool = False,
145
+ recognize_math: bool = True,
146
+ with_bboxes: bool = True,
147
+ ) -> (Image.Image, OCRResult):
148
+ if skip_text_detection:
149
+ img = highres_img
150
+ bboxes = [[[0, 0, img.width, img.height]]]
151
+ else:
152
+ bboxes = None
153
 
154
+ if with_bboxes:
155
+ tasks = [TaskNames.ocr_with_boxes]
156
+ else:
157
+ tasks = [TaskNames.ocr_without_boxes]
158
+
159
+ img_pred = predictors["recognition"](
160
+ [img],
161
+ task_names=tasks,
162
+ bboxes=bboxes,
163
+ det_predictor=predictors["detection"],
164
+ highres_images=[highres_img],
165
+ math_mode=recognize_math,
166
+ return_words=True,
167
+ )[0]
168
+
169
+ bboxes = [line.bbox for line in img_pred.text_lines]
170
+ text = [line.text for line in img_pred.text_lines]
171
+ rec_img = draw_text_on_image(bboxes, text, img.size)
172
+
173
+ word_boxes = []
174
+ for line in img_pred.text_lines:
175
+ if line.words:
176
+ word_boxes.extend([word.bbox for word in line.words])
177
+
178
+ box_img = img.copy()
179
+ draw = ImageDraw.Draw(box_img)
180
+ for word_box in word_boxes:
181
+ draw.rectangle(word_box, outline="red", width=2)
182
+
183
+ return rec_img, img_pred, box_img
184
 
185
  def open_pdf(pdf_file):
186
  return pypdfium2.PdfDocument(pdf_file)
 
231
  in_img = gr.Image(label="Select page of Image", type="pil", sources=None)
232
 
233
  text_det_btn = gr.Button("Run Text Detection")
 
234
  layout_det_btn = gr.Button("Run Layout Analysis")
235
 
236
+ skip_text_detection_ckb = gr.Checkbox(label="Skip text detection", value=False, info="OCR only: Skip text detection and treat the whole image as a single line.")
237
+ recognize_math_ckb = gr.Checkbox(label="Recognize math in OCR", value=True, info="Enable math mode in OCR - this will recognize math.")
238
+ ocr_with_boxes_ckb = gr.Checkbox(label="OCR with boxes", value=True, info="Enable OCR with boxes - this will predict character-level boxes.")
239
  text_rec_btn = gr.Button("Run OCR")
240
 
 
241
  skip_table_detection_ckb = gr.Checkbox(label="Skip table detection", value=False, info="Table recognition only: Skip table detection and treat the whole image/page as a table.")
242
  table_rec_btn = gr.Button("Run Table Rec")
243
 
 
245
  with gr.Column():
246
  result_img = gr.Image(label="Result image")
247
  result_json = gr.JSON(label="Result json")
248
+ ocr_boxes_img = gr.Image(label="OCR boxes image")
249
 
250
  def show_image(file, num=1):
251
  if file.endswith('.pdf'):
 
280
  inputs=[in_img],
281
  outputs=[result_img, result_json]
282
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  # Run layout
285
  def layout_det_img(pil_image):
 
290
  inputs=[in_img],
291
  outputs=[result_img, result_json]
292
  )
293
+
294
  # Run OCR
295
+ def text_rec_img(pil_image, in_file, page_number, skip_text_detection, recognize_math, ocr_with_boxes):
296
  if in_file.endswith('.pdf'):
297
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
298
  else:
299
  pil_image_highres = pil_image
300
+ rec_img, pred, box_img = ocr(
301
+ pil_image,
302
+ pil_image_highres,
303
+ skip_text_detection,
304
+ recognize_math,
305
+ with_bboxes=ocr_with_boxes,
306
+ )
307
+ return rec_img, pred.model_dump(), box_img
308
  text_rec_btn.click(
309
  fn=text_rec_img,
310
+ inputs=[in_img, in_file, in_num, skip_text_detection_ckb, recognize_math_ckb, ocr_with_boxes_ckb],
311
+ outputs=[result_img, result_json, ocr_boxes_img]
312
  )
313
+
314
  # Run Table Recognition
315
+ def table_rec_img(pil_image, in_file, page_number, skip_table_detection):
316
  if in_file.endswith('.pdf'):
317
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
318
  else:
 
321
  return table_img, [p.model_dump() for p in pred]
322
  table_rec_btn.click(
323
  fn=table_rec_img,
324
+ inputs=[in_img, in_file, in_num, skip_table_detection_ckb],
325
  outputs=[result_img, result_json]
326
  )
327
+
328
  # Run bad PDF text detection
329
+ def ocr_errors_pdf(in_file):
330
+ if not in_file.endswith('.pdf'):
 
 
331
  raise gr.Error("This feature only works with PDFs.", duration=5)
332
+ page_count = page_counter(in_file)
333
+ io_file = io.BytesIO(open(in_file.name, "rb").read())
334
+ label, results = ocr_errors(io_file, page_count)
335
  return gr.update(label="Result json:" + label, value=results)
336
  ocr_errors_btn.click(
337
  fn=ocr_errors_pdf,
338
+ inputs=[in_file],
339
  outputs=[result_json]
340
  )
341