xiaoyao9184 commited on
Commit
8d87e45
·
verified ·
1 Parent(s): 7e431e5

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (1) hide show
  1. gradio_app.py +51 -81
gradio_app.py CHANGED
@@ -9,59 +9,40 @@ if "APP_PATH" in os.environ:
9
  if app_path not in sys.path:
10
  sys.path.append(app_path)
11
 
12
- import gradio as gr
13
-
14
  from typing import List
15
 
16
  import pypdfium2
17
- from pypdfium2 import PdfiumError
18
-
19
- from surya.detection import batch_text_detection
20
- from surya.input.pdflines import get_page_text_lines, get_table_blocks
21
- from surya.layout import batch_layout_detection
22
- from surya.model.detection.model import load_model, load_processor
23
- from surya.model.layout.model import load_model as load_layout_model
24
- from surya.model.layout.processor import load_processor as load_layout_processor
25
- from surya.model.recognition.model import load_model as load_rec_model
26
- from surya.model.recognition.processor import load_processor as load_rec_processor
27
- from surya.model.table_rec.model import load_model as load_table_model
28
- from surya.model.table_rec.processor import load_processor as load_table_processor
29
- from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
30
- from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
31
- from surya.ocr import run_ocr
32
- from surya.postprocessing.text import draw_text_on_image
33
- from PIL import Image
34
- from surya.languages import CODE_TO_LANGUAGE
35
- from surya.input.langs import replace_lang_with_code
36
- from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
37
- from surya.settings import settings
38
- from surya.tables import batch_table_recognition
39
- from surya.postprocessing.util import rescale_bbox
40
- from pdftext.extraction import plain_text_output
41
- from surya.ocr_error import batch_ocr_error_detection
42
-
43
 
44
- def load_det_cached():
45
- return load_model(), load_processor()
46
 
47
- def load_rec_cached():
48
- return load_rec_model(), load_rec_processor()
49
 
50
- def load_layout_cached():
51
- return load_layout_model(), load_layout_processor()
52
 
53
- def load_table_cached():
54
- return load_table_model(), load_table_processor()
 
 
 
 
 
 
55
 
56
- def load_ocr_error_cached():
57
- return load_ocr_error_model(), load_ocr_error_processor()
58
 
59
- #
60
  def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
61
- # Sample the text from the middle of the PDF
62
- page_middle = page_count // 2
63
- page_range = range(max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count))
64
- text = plain_text_output(pdf_file, page_range=page_range)
 
 
 
 
 
65
 
66
  sample_gap = len(text) // max_samples
67
  if len(text) == 0 or sample_gap == 0:
@@ -75,29 +56,29 @@ def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pag
75
  for i in range(0, len(text), sample_gap):
76
  samples.append(text[i:i + sample_len])
77
 
78
- results = batch_ocr_error_detection(samples, ocr_error_model, ocr_error_processor)
79
  label = "This PDF has good text."
80
  if results.labels.count("bad") / len(results.labels) > .2:
81
  label = "This PDF may have garbled or bad OCR text."
82
  return label, results.labels
83
 
84
- #
85
  def text_detection(img) -> (Image.Image, TextDetectionResult):
86
- pred = batch_text_detection([img], det_model, det_processor)[0]
87
  polygons = [p.polygon for p in pred.bboxes]
88
  det_img = draw_polys_on_image(polygons, img.copy())
89
  return det_img, pred
90
 
91
- #
92
  def layout_detection(img) -> (Image.Image, LayoutResult):
93
- pred = batch_layout_detection([img], layout_model, layout_processor)[0]
94
  polygons = [p.polygon for p in pred.bboxes]
95
  labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
96
  layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
97
  return layout_img, pred
98
 
99
- #
100
- def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
101
  if skip_table_detection:
102
  layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
103
  table_imgs = [highres_img]
@@ -108,23 +89,14 @@ def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes:
108
  layout_tables = []
109
  for tb in layout_tables_lowres:
110
  highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
 
 
111
  table_imgs.append(
112
  highres_img.crop(highres_bbox)
113
  )
114
  layout_tables.append(highres_bbox)
115
 
116
- try:
117
- page_text = get_page_text_lines(filepath, [page_idx], [highres_img.size])[0]
118
- table_bboxes = get_table_blocks(layout_tables, page_text, highres_img.size)
119
- except PdfiumError:
120
- # This happens when we try to get text from an image
121
- table_bboxes = [[] for _ in layout_tables]
122
-
123
- if not use_pdf_boxes or any(len(tb) == 0 for tb in table_bboxes):
124
- det_results = batch_text_detection(table_imgs, det_model, det_processor)
125
- table_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]
126
-
127
- table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)
128
  table_img = img.copy()
129
 
130
  for results, table_bbox in zip(table_preds, layout_tables):
@@ -132,7 +104,7 @@ def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes:
132
  labels = []
133
  colors = []
134
 
135
- for item in results.rows + results.cols:
136
  adjusted_bboxes.append([
137
  (item.bbox[0] + table_bbox[0]),
138
  (item.bbox[1] + table_bbox[1]),
@@ -140,31 +112,33 @@ def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes:
140
  (item.bbox[3] + table_bbox[1])
141
  ])
142
  labels.append(item.label)
143
- if hasattr(item, "row_id"):
144
  colors.append("blue")
145
  else:
146
  colors.append("red")
147
  table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors)
148
  return table_img, table_preds
149
 
150
- # Function for OCR
151
  def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
152
  replace_lang_with_code(langs)
153
- img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor, highres_images=[highres_img])[0]
154
 
155
  bboxes = [l.bbox for l in img_pred.text_lines]
156
  text = [l.text for l in img_pred.text_lines]
157
- rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs)
158
  return rec_img, img_pred
159
 
160
  def open_pdf(pdf_file):
161
  return pypdfium2.PdfDocument(pdf_file)
162
 
163
- def count_pdf(pdf_file):
164
  doc = open_pdf(pdf_file)
165
- return len(doc)
 
 
166
 
167
- def get_page_image(pdf_file, page_num, dpi=96):
168
  doc = open_pdf(pdf_file)
169
  renderer = doc.render(
170
  pypdfium2.PdfBitmap.to_pil,
@@ -173,18 +147,14 @@ def get_page_image(pdf_file, page_num, dpi=96):
173
  )
174
  png = list(renderer)[0]
175
  png_image = png.convert("RGB")
 
176
  return png_image
177
 
178
  def get_uploaded_image(in_file):
179
  return Image.open(in_file).convert("RGB")
180
 
181
  # Load models if not already loaded in reload mode
182
- if 'det_model' not in globals():
183
- det_model, det_processor = load_det_cached()
184
- rec_model, rec_processor = load_rec_cached()
185
- layout_model, layout_processor = load_layout_cached()
186
- table_model, table_processor = load_table_cached()
187
- ocr_error_model, ocr_error_processor = load_ocr_error_cached()
188
 
189
  with gr.Blocks(title="Surya") as demo:
190
  gr.Markdown("""
@@ -224,8 +194,8 @@ with gr.Blocks(title="Surya") as demo:
224
 
225
  def show_image(file, num=1):
226
  if file.endswith('.pdf'):
227
- count = count_pdf(file)
228
- img = get_page_image(file, num)
229
  return [
230
  gr.update(visible=True, maximum=count),
231
  gr.update(value=img)]
@@ -283,7 +253,7 @@ with gr.Blocks(title="Surya") as demo:
283
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
284
  else:
285
  pil_image_highres = pil_image
286
- table_img, pred = table_recognition(pil_image, pil_image_highres, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection)
287
  return table_img, [p.model_dump() for p in pred]
288
  table_rec_btn.click(
289
  fn=table_rec_img,
@@ -293,10 +263,10 @@ with gr.Blocks(title="Surya") as demo:
293
  # Run bad PDF text detection
294
  def ocr_errors_pdf(file, page_count, sample_len=512, max_samples=10, max_pages=15):
295
  if file.endswith('.pdf'):
296
- count = count_pdf(file)
297
  else:
298
  raise gr.Error("This feature only works with PDFs.", duration=5)
299
- label, results = run_ocr_errors(file, count)
300
  return gr.update(label="Result json:" + label, value=results)
301
  ocr_errors_btn.click(
302
  fn=ocr_errors_pdf,
 
9
  if app_path not in sys.path:
10
  sys.path.append(app_path)
11
 
12
+ import io
13
+ import tempfile
14
  from typing import List
15
 
16
  import pypdfium2
17
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ from surya.models import load_predictors
 
20
 
21
+ from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image
 
22
 
23
+ from surya.debug.text import draw_text_on_image
 
24
 
25
+ from PIL import Image
26
+ from surya.recognition.languages import CODE_TO_LANGUAGE, replace_lang_with_code
27
+ from surya.table_rec import TableResult
28
+ from surya.detection import TextDetectionResult
29
+ from surya.recognition import OCRResult
30
+ from surya.layout import LayoutResult
31
+ from surya.settings import settings
32
+ from surya.common.util import rescale_bbox, expand_bbox
33
 
 
 
34
 
35
+ # just copy from streamlit_app.py
36
  def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
37
+ from pdftext.extraction import plain_text_output
38
+ with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
39
+ f.write(pdf_file.getvalue())
40
+ f.seek(0)
41
+
42
+ # Sample the text from the middle of the PDF
43
+ page_middle = page_count // 2
44
+ page_range = range(max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count))
45
+ text = plain_text_output(f.name, page_range=page_range)
46
 
47
  sample_gap = len(text) // max_samples
48
  if len(text) == 0 or sample_gap == 0:
 
56
  for i in range(0, len(text), sample_gap):
57
  samples.append(text[i:i + sample_len])
58
 
59
+ results = predictors["ocr_error"](samples)
60
  label = "This PDF has good text."
61
  if results.labels.count("bad") / len(results.labels) > .2:
62
  label = "This PDF may have garbled or bad OCR text."
63
  return label, results.labels
64
 
65
+ # just copy from streamlit_app.py
66
  def text_detection(img) -> (Image.Image, TextDetectionResult):
67
+ pred = predictors["detection"]([img])[0]
68
  polygons = [p.polygon for p in pred.bboxes]
69
  det_img = draw_polys_on_image(polygons, img.copy())
70
  return det_img, pred
71
 
72
+ # just copy from streamlit_app.py
73
  def layout_detection(img) -> (Image.Image, LayoutResult):
74
+ pred = predictors["layout"]([img])[0]
75
  polygons = [p.polygon for p in pred.bboxes]
76
  labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
77
  layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
78
  return layout_img, pred
79
 
80
+ # just copy from streamlit_app.py
81
+ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
82
  if skip_table_detection:
83
  layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
84
  table_imgs = [highres_img]
 
89
  layout_tables = []
90
  for tb in layout_tables_lowres:
91
  highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
92
+ # Slightly expand the box
93
+ highres_bbox = expand_bbox(highres_bbox)
94
  table_imgs.append(
95
  highres_img.crop(highres_bbox)
96
  )
97
  layout_tables.append(highres_bbox)
98
 
99
+ table_preds = predictors["table_rec"](table_imgs)
 
 
 
 
 
 
 
 
 
 
 
100
  table_img = img.copy()
101
 
102
  for results, table_bbox in zip(table_preds, layout_tables):
 
104
  labels = []
105
  colors = []
106
 
107
+ for item in results.cells:
108
  adjusted_bboxes.append([
109
  (item.bbox[0] + table_bbox[0]),
110
  (item.bbox[1] + table_bbox[1]),
 
112
  (item.bbox[3] + table_bbox[1])
113
  ])
114
  labels.append(item.label)
115
+ if "Row" in item.label:
116
  colors.append("blue")
117
  else:
118
  colors.append("red")
119
  table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors)
120
  return table_img, table_preds
121
 
122
+ # just copy from streamlit_app.py
123
  def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
124
  replace_lang_with_code(langs)
125
+ img_pred = predictors["recognition"]([img], [langs], predictors["detection"], highres_images=[highres_img])[0]
126
 
127
  bboxes = [l.bbox for l in img_pred.text_lines]
128
  text = [l.text for l in img_pred.text_lines]
129
+ rec_img = draw_text_on_image(bboxes, text, img.size, langs)
130
  return rec_img, img_pred
131
 
132
  def open_pdf(pdf_file):
133
  return pypdfium2.PdfDocument(pdf_file)
134
 
135
+ def page_counter(pdf_file):
136
  doc = open_pdf(pdf_file)
137
+ doc_len = len(doc)
138
+ doc.close()
139
+ return doc_len
140
 
141
+ def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
142
  doc = open_pdf(pdf_file)
143
  renderer = doc.render(
144
  pypdfium2.PdfBitmap.to_pil,
 
147
  )
148
  png = list(renderer)[0]
149
  png_image = png.convert("RGB")
150
+ doc.close()
151
  return png_image
152
 
153
  def get_uploaded_image(in_file):
154
  return Image.open(in_file).convert("RGB")
155
 
156
  # Load models if not already loaded in reload mode
157
+ predictors = load_predictors()
 
 
 
 
 
158
 
159
  with gr.Blocks(title="Surya") as demo:
160
  gr.Markdown("""
 
194
 
195
  def show_image(file, num=1):
196
  if file.endswith('.pdf'):
197
+ count = page_counter(file)
198
+ img = get_page_image(file, num, settings.IMAGE_DPI)
199
  return [
200
  gr.update(visible=True, maximum=count),
201
  gr.update(value=img)]
 
253
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
254
  else:
255
  pil_image_highres = pil_image
256
+ table_img, pred = table_recognition(pil_image, pil_image_highres, skip_table_detection)
257
  return table_img, [p.model_dump() for p in pred]
258
  table_rec_btn.click(
259
  fn=table_rec_img,
 
263
  # Run bad PDF text detection
264
  def ocr_errors_pdf(file, page_count, sample_len=512, max_samples=10, max_pages=15):
265
  if file.endswith('.pdf'):
266
+ count = page_counter(file)
267
  else:
268
  raise gr.Error("This feature only works with PDFs.", duration=5)
269
+ label, results = run_ocr_errors(io.BytesIO(open(file.name, "rb").read()), count)
270
  return gr.update(label="Result json:" + label, value=results)
271
  ocr_errors_btn.click(
272
  fn=ocr_errors_pdf,