xiaoyao9184 committed on
Commit
be1602a
·
verified ·
1 Parent(s): 1c761aa

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (2) hide show
  1. gradio_app.py +82 -52
  2. requirements.txt +2 -9
gradio_app.py CHANGED
@@ -9,40 +9,59 @@ if "APP_PATH" in os.environ:
9
  if app_path not in sys.path:
10
  sys.path.append(app_path)
11
 
12
- import io
13
- import tempfile
14
  from typing import List
15
 
16
  import pypdfium2
17
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- from surya.models import load_predictors
20
 
21
- from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image
 
22
 
23
- from surya.debug.text import draw_text_on_image
 
24
 
25
- from PIL import Image
26
- from surya.recognition.languages import CODE_TO_LANGUAGE, replace_lang_with_code
27
- from surya.table_rec import TableResult
28
- from surya.detection import TextDetectionResult
29
- from surya.recognition import OCRResult
30
- from surya.layout import LayoutResult
31
- from surya.settings import settings
32
- from surya.common.util import rescale_bbox, expand_bbox
33
 
 
 
34
 
35
- # just copy from streamlit_app.py
36
- def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
37
- from pdftext.extraction import plain_text_output
38
- with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
39
- f.write(pdf_file.getvalue())
40
- f.seek(0)
41
 
42
- # Sample the text from the middle of the PDF
43
- page_middle = page_count // 2
44
- page_range = range(max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count))
45
- text = plain_text_output(f.name, page_range=page_range)
 
 
46
 
47
  sample_gap = len(text) // max_samples
48
  if len(text) == 0 or sample_gap == 0:
@@ -56,47 +75,56 @@ def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pag
56
  for i in range(0, len(text), sample_gap):
57
  samples.append(text[i:i + sample_len])
58
 
59
- results = predictors["ocr_error"](samples)
60
  label = "This PDF has good text."
61
  if results.labels.count("bad") / len(results.labels) > .2:
62
  label = "This PDF may have garbled or bad OCR text."
63
  return label, results.labels
64
 
65
- # just copy from streamlit_app.py
66
  def text_detection(img) -> (Image.Image, TextDetectionResult):
67
- pred = predictors["detection"]([img])[0]
68
  polygons = [p.polygon for p in pred.bboxes]
69
  det_img = draw_polys_on_image(polygons, img.copy())
70
  return det_img, pred
71
 
72
- # just copy from streamlit_app.py
73
  def layout_detection(img) -> (Image.Image, LayoutResult):
74
- pred = predictors["layout"]([img])[0]
75
  polygons = [p.polygon for p in pred.bboxes]
76
  labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
77
  layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
78
  return layout_img, pred
79
 
80
- # just copy from streamlit_app.py
81
- def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
82
  if skip_table_detection:
83
  layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
84
  table_imgs = [highres_img]
85
  else:
86
  _, layout_pred = layout_detection(img)
87
- layout_tables_lowres = [l.bbox for l in layout_pred.bboxes if l.label in ["Table", "TableOfContents"]]
88
  table_imgs = []
89
  layout_tables = []
90
  for tb in layout_tables_lowres:
91
  highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
92
- # Slightly expand the box
93
- highres_bbox = expand_bbox(highres_bbox)
94
  table_imgs.append(
95
  highres_img.crop(highres_bbox)
96
  )
97
  layout_tables.append(highres_bbox)
98
 
99
- table_preds = predictors["table_rec"](table_imgs)
 
 
 
 
 
 
 
 
 
 
 
100
  table_img = img.copy()
101
 
102
  for results, table_bbox in zip(table_preds, layout_tables):
@@ -104,7 +132,7 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Im
104
  labels = []
105
  colors = []
106
 
107
- for item in results.cells:
108
  adjusted_bboxes.append([
109
  (item.bbox[0] + table_bbox[0]),
110
  (item.bbox[1] + table_bbox[1]),
@@ -112,33 +140,31 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Im
112
  (item.bbox[3] + table_bbox[1])
113
  ])
114
  labels.append(item.label)
115
- if "Row" in item.label:
116
  colors.append("blue")
117
  else:
118
  colors.append("red")
119
  table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors)
120
  return table_img, table_preds
121
 
122
- # just copy from streamlit_app.py
123
  def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
124
  replace_lang_with_code(langs)
125
- img_pred = predictors["recognition"]([img], [langs], predictors["detection"], highres_images=[highres_img])[0]
126
 
127
  bboxes = [l.bbox for l in img_pred.text_lines]
128
  text = [l.text for l in img_pred.text_lines]
129
- rec_img = draw_text_on_image(bboxes, text, img.size, langs)
130
  return rec_img, img_pred
131
 
132
  def open_pdf(pdf_file):
133
  return pypdfium2.PdfDocument(pdf_file)
134
 
135
- def page_counter(pdf_file):
136
  doc = open_pdf(pdf_file)
137
- doc_len = len(doc)
138
- doc.close()
139
- return doc_len
140
 
141
- def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
142
  doc = open_pdf(pdf_file)
143
  renderer = doc.render(
144
  pypdfium2.PdfBitmap.to_pil,
@@ -147,14 +173,18 @@ def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
147
  )
148
  png = list(renderer)[0]
149
  png_image = png.convert("RGB")
150
- doc.close()
151
  return png_image
152
 
153
  def get_uploaded_image(in_file):
154
  return Image.open(in_file).convert("RGB")
155
 
156
  # Load models if not already loaded in reload mode
157
- predictors = load_predictors()
 
 
 
 
 
158
 
159
  with gr.Blocks(title="Surya") as demo:
160
  gr.Markdown("""
@@ -194,8 +224,8 @@ with gr.Blocks(title="Surya") as demo:
194
 
195
  def show_image(file, num=1):
196
  if file.endswith('.pdf'):
197
- count = page_counter(file)
198
- img = get_page_image(file, num, settings.IMAGE_DPI)
199
  return [
200
  gr.update(visible=True, maximum=count),
201
  gr.update(value=img)]
@@ -253,7 +283,7 @@ with gr.Blocks(title="Surya") as demo:
253
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
254
  else:
255
  pil_image_highres = pil_image
256
- table_img, pred = table_recognition(pil_image, pil_image_highres, skip_table_detection)
257
  return table_img, [p.model_dump() for p in pred]
258
  table_rec_btn.click(
259
  fn=table_rec_img,
@@ -263,10 +293,10 @@ with gr.Blocks(title="Surya") as demo:
263
  # Run bad PDF text detection
264
  def ocr_errors_pdf(file, page_count, sample_len=512, max_samples=10, max_pages=15):
265
  if file.endswith('.pdf'):
266
- count = page_counter(file)
267
  else:
268
  raise gr.Error("This feature only works with PDFs.", duration=5)
269
- label, results = run_ocr_errors(io.BytesIO(open(file.name, "rb").read()), count)
270
  return gr.update(label="Result json:" + label, value=results)
271
  ocr_errors_btn.click(
272
  fn=ocr_errors_pdf,
 
9
  if app_path not in sys.path:
10
  sys.path.append(app_path)
11
 
12
+ import gradio as gr
13
+
14
  from typing import List
15
 
16
  import pypdfium2
17
+ from pypdfium2 import PdfiumError
18
+
19
+ from surya.detection import batch_text_detection
20
+ from surya.input.pdflines import get_page_text_lines, get_table_blocks
21
+ from surya.layout import batch_layout_detection
22
+ from surya.model.detection.model import load_model, load_processor
23
+ from surya.model.layout.model import load_model as load_layout_model
24
+ from surya.model.layout.processor import load_processor as load_layout_processor
25
+ from surya.model.recognition.model import load_model as load_rec_model
26
+ from surya.model.recognition.processor import load_processor as load_rec_processor
27
+ from surya.model.table_rec.model import load_model as load_table_model
28
+ from surya.model.table_rec.processor import load_processor as load_table_processor
29
+ from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
30
+ from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
31
+ from surya.ocr import run_ocr
32
+ from surya.postprocessing.text import draw_text_on_image
33
+ from PIL import Image
34
+ from surya.languages import CODE_TO_LANGUAGE
35
+ from surya.input.langs import replace_lang_with_code
36
+ from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
37
+ from surya.settings import settings
38
+ from surya.tables import batch_table_recognition
39
+ from surya.postprocessing.util import rescale_bbox
40
+ from pdftext.extraction import plain_text_output
41
+ from surya.ocr_error import batch_ocr_error_detection
42
 
 
43
 
44
+ def load_det_cached():
45
+ return load_model(), load_processor()
46
 
47
+ def load_rec_cached():
48
+ return load_rec_model(), load_rec_processor()
49
 
50
+ def load_layout_cached():
51
+ return load_layout_model(), load_layout_processor()
 
 
 
 
 
 
52
 
53
+ def load_table_cached():
54
+ return load_table_model(), load_table_processor()
55
 
56
+ def load_ocr_error_cached():
57
+ return load_ocr_error_model(), load_ocr_error_processor()
 
 
 
 
58
 
59
+ #
60
+ def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
61
+ # Sample the text from the middle of the PDF
62
+ page_middle = page_count // 2
63
+ page_range = range(max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count))
64
+ text = plain_text_output(pdf_file, page_range=page_range)
65
 
66
  sample_gap = len(text) // max_samples
67
  if len(text) == 0 or sample_gap == 0:
 
75
  for i in range(0, len(text), sample_gap):
76
  samples.append(text[i:i + sample_len])
77
 
78
+ results = batch_ocr_error_detection(samples, ocr_error_model, ocr_error_processor)
79
  label = "This PDF has good text."
80
  if results.labels.count("bad") / len(results.labels) > .2:
81
  label = "This PDF may have garbled or bad OCR text."
82
  return label, results.labels
83
 
84
+ #
85
  def text_detection(img) -> (Image.Image, TextDetectionResult):
86
+ pred = batch_text_detection([img], det_model, det_processor)[0]
87
  polygons = [p.polygon for p in pred.bboxes]
88
  det_img = draw_polys_on_image(polygons, img.copy())
89
  return det_img, pred
90
 
91
+ #
92
  def layout_detection(img) -> (Image.Image, LayoutResult):
93
+ pred = batch_layout_detection([img], layout_model, layout_processor)[0]
94
  polygons = [p.polygon for p in pred.bboxes]
95
  labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
96
  layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
97
  return layout_img, pred
98
 
99
+ #
100
+ def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
101
  if skip_table_detection:
102
  layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
103
  table_imgs = [highres_img]
104
  else:
105
  _, layout_pred = layout_detection(img)
106
+ layout_tables_lowres = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
107
  table_imgs = []
108
  layout_tables = []
109
  for tb in layout_tables_lowres:
110
  highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
 
 
111
  table_imgs.append(
112
  highres_img.crop(highres_bbox)
113
  )
114
  layout_tables.append(highres_bbox)
115
 
116
+ try:
117
+ page_text = get_page_text_lines(filepath, [page_idx], [highres_img.size])[0]
118
+ table_bboxes = get_table_blocks(layout_tables, page_text, highres_img.size)
119
+ except PdfiumError:
120
+ # This happens when we try to get text from an image
121
+ table_bboxes = [[] for _ in layout_tables]
122
+
123
+ if not use_pdf_boxes or any(len(tb) == 0 for tb in table_bboxes):
124
+ det_results = batch_text_detection(table_imgs, det_model, det_processor)
125
+ table_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]
126
+
127
+ table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)
128
  table_img = img.copy()
129
 
130
  for results, table_bbox in zip(table_preds, layout_tables):
 
132
  labels = []
133
  colors = []
134
 
135
+ for item in results.rows + results.cols:
136
  adjusted_bboxes.append([
137
  (item.bbox[0] + table_bbox[0]),
138
  (item.bbox[1] + table_bbox[1]),
 
140
  (item.bbox[3] + table_bbox[1])
141
  ])
142
  labels.append(item.label)
143
+ if hasattr(item, "row_id"):
144
  colors.append("blue")
145
  else:
146
  colors.append("red")
147
  table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors)
148
  return table_img, table_preds
149
 
150
+ # Function for OCR
151
  def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
152
  replace_lang_with_code(langs)
153
+ img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor, highres_images=[highres_img])[0]
154
 
155
  bboxes = [l.bbox for l in img_pred.text_lines]
156
  text = [l.text for l in img_pred.text_lines]
157
+ rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs)
158
  return rec_img, img_pred
159
 
160
  def open_pdf(pdf_file):
161
  return pypdfium2.PdfDocument(pdf_file)
162
 
163
+ def count_pdf(pdf_file):
164
  doc = open_pdf(pdf_file)
165
+ return len(doc)
 
 
166
 
167
+ def get_page_image(pdf_file, page_num, dpi=96):
168
  doc = open_pdf(pdf_file)
169
  renderer = doc.render(
170
  pypdfium2.PdfBitmap.to_pil,
 
173
  )
174
  png = list(renderer)[0]
175
  png_image = png.convert("RGB")
 
176
  return png_image
177
 
178
  def get_uploaded_image(in_file):
179
  return Image.open(in_file).convert("RGB")
180
 
181
  # Load models if not already loaded in reload mode
182
+ if 'det_model' not in globals():
183
+ det_model, det_processor = load_det_cached()
184
+ rec_model, rec_processor = load_rec_cached()
185
+ layout_model, layout_processor = load_layout_cached()
186
+ table_model, table_processor = load_table_cached()
187
+ ocr_error_model, ocr_error_processor = load_ocr_error_cached()
188
 
189
  with gr.Blocks(title="Surya") as demo:
190
  gr.Markdown("""
 
224
 
225
  def show_image(file, num=1):
226
  if file.endswith('.pdf'):
227
+ count = count_pdf(file)
228
+ img = get_page_image(file, num)
229
  return [
230
  gr.update(visible=True, maximum=count),
231
  gr.update(value=img)]
 
283
  pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
284
  else:
285
  pil_image_highres = pil_image
286
+ table_img, pred = table_recognition(pil_image, pil_image_highres, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection)
287
  return table_img, [p.model_dump() for p in pred]
288
  table_rec_btn.click(
289
  fn=table_rec_img,
 
293
  # Run bad PDF text detection
294
  def ocr_errors_pdf(file, page_count, sample_len=512, max_samples=10, max_pages=15):
295
  if file.endswith('.pdf'):
296
+ count = count_pdf(file)
297
  else:
298
  raise gr.Error("This feature only works with PDFs.", duration=5)
299
+ label, results = run_ocr_errors(file, count)
300
  return gr.update(label="Result json:" + label, value=results)
301
  ocr_errors_btn.click(
302
  fn=ocr_errors_pdf,
requirements.txt CHANGED
@@ -1,11 +1,4 @@
1
  torch==2.5.1
2
- surya-ocr==0.10.0
3
  gradio==5.8.0
4
- huggingface-hub==0.26.3
5
- # gradio app need pdftext for run_ocr_errors
6
- pdftext==0.5.0
7
-
8
- # fix compatibility issue keep same with poetry lock file
9
- transformers==4.48.1
10
- # fix https://github.com/gradio-app/gradio/issues/10662
11
- pydantic==2.10.5
 
1
  torch==2.5.1
2
+ surya-ocr==0.8.3
3
  gradio==5.8.0
4
+ huggingface-hub==0.26.3