Synced repo using 'sync_with_huggingface' Github Action
gradio_app.py CHANGED (+107 -67)
@@ -16,14 +16,13 @@ from typing import List
 import pypdfium2
 import gradio as gr
 
+from surya.common.surya.schema import TaskNames
 from surya.models import load_predictors
 
 from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image
 
 from surya.debug.text import draw_text_on_image
-
-from PIL import Image
-from surya.recognition.languages import CODE_TO_LANGUAGE, replace_lang_with_code
+from PIL import Image, ImageDraw
 from surya.table_rec import TableResult
 from surya.detection import TextDetectionResult
 from surya.recognition import OCRResult
@@ -33,15 +32,18 @@ from surya.common.util import rescale_bbox, expand_bbox
 
 
 # just copy from streamlit_app.py
-def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
+def ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
     from pdftext.extraction import plain_text_output
+
     with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
         f.write(pdf_file.getvalue())
         f.seek(0)
 
         # Sample the text from the middle of the PDF
         page_middle = page_count // 2
-        page_range = range(max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count))
+        page_range = range(
+            max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count)
+        )
         text = plain_text_output(f.name, page_range=page_range)
 
         sample_gap = len(text) // max_samples
@@ -54,24 +56,14 @@ def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
         # Split the text into samples for the model
         samples = []
         for i in range(0, len(text), sample_gap):
-            samples.append(text[i:i + sample_len])
+            samples.append(text[i : i + sample_len])
 
     results = predictors["ocr_error"](samples)
     label = "This PDF has good text."
-    if results.labels.count("bad") / len(results.labels) > .2:
+    if results.labels.count("bad") / len(results.labels) > 0.2:
         label = "This PDF may have garbled or bad OCR text."
     return label, results.labels
 
-# just copy from streamlit_app.py
-def inline_detection(img) -> (Image.Image, TextDetectionResult):
-    text_pred = predictors["detection"]([img])[0]
-    text_boxes = [p.bbox for p in text_pred.bboxes]
-
-    inline_pred = predictors["inline_detection"]([img], [text_boxes], include_maps=True)[0]
-    inline_polygons = [p.polygon for p in inline_pred.bboxes]
-    det_img = draw_polys_on_image(inline_polygons, img.copy(), color='blue')
-    return det_img, text_pred, inline_pred
-
 # just copy from streamlit_app.py
 def text_detection(img) -> (Image.Image, TextDetectionResult):
     text_pred = predictors["detection"]([img])[0]
@@ -83,27 +75,35 @@ def text_detection(img) -> (Image.Image, TextDetectionResult):
 def layout_detection(img) -> (Image.Image, LayoutResult):
     pred = predictors["layout"]([img])[0]
     polygons = [p.polygon for p in pred.bboxes]
-    labels = [f"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}" for p in pred.bboxes]
-    layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
+    labels = [
+        f"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}" for p in pred.bboxes
+    ]
+    layout_img = draw_polys_on_image(
+        polygons, img.copy(), labels=labels, label_font_size=18
+    )
     return layout_img, pred
 
 # just copy from streamlit_app.py
-def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
+def table_recognition(
+    img, highres_img, skip_table_detection: bool
+) -> (Image.Image, List[TableResult]):
     if skip_table_detection:
         layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
         table_imgs = [highres_img]
     else:
         _, layout_pred = layout_detection(img)
-        layout_tables_lowres = [line.bbox for line in layout_pred.bboxes if line.label in ["Table", "TableOfContents"]]
+        layout_tables_lowres = [
+            line.bbox
+            for line in layout_pred.bboxes
+            if line.label in ["Table", "TableOfContents"]
+        ]
         table_imgs = []
         layout_tables = []
         for tb in layout_tables_lowres:
             highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
             # Slightly expand the box
             highres_bbox = expand_bbox(highres_bbox)
-            table_imgs.append(
-                highres_img.crop(highres_bbox)
-            )
+            table_imgs.append(highres_img.crop(highres_bbox))
             layout_tables.append(highres_bbox)
 
     table_preds = predictors["table_rec"](table_imgs)
@@ -115,29 +115,72 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
     colors = []
 
     for item in results.cells:
-        adjusted_bboxes.append([
-            (item.bbox[0] + table_bbox[0]),
-            (item.bbox[1] + table_bbox[1]),
-            (item.bbox[2] + table_bbox[0]),
-            (item.bbox[3] + table_bbox[1])
-        ])
+        adjusted_bboxes.append(
+            [
+                (item.bbox[0] + table_bbox[0]),
+                (item.bbox[1] + table_bbox[1]),
+                (item.bbox[2] + table_bbox[0]),
+                (item.bbox[3] + table_bbox[1]),
+            ]
+        )
         labels.append(item.label)
         if "Row" in item.label:
             colors.append("blue")
         else:
             colors.append("red")
-    table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors)
+    table_img = draw_bboxes_on_image(
+        adjusted_bboxes,
+        highres_img,
+        labels=labels,
+        label_font_size=18,
+        color=colors,
+    )
     return table_img, table_preds
 
 # just copy from streamlit_app.py
-def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
-    replace_lang_with_code(langs)
-    img_pred = predictors["recognition"]([img], [langs], predictors["detection"], highres_images=[highres_img])[0]
-
-    bboxes = [l.bbox for l in img_pred.text_lines]
-    text = [l.text for l in img_pred.text_lines]
-    rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs)
-    return rec_img, img_pred
+def ocr(
+    img: Image.Image,
+    highres_img: Image.Image,
+    skip_text_detection: bool = False,
+    recognize_math: bool = True,
+    with_bboxes: bool = True,
+) -> (Image.Image, OCRResult):
+    if skip_text_detection:
+        img = highres_img
+        bboxes = [[[0, 0, img.width, img.height]]]
+    else:
+        bboxes = None
+
+    if with_bboxes:
+        tasks = [TaskNames.ocr_with_boxes]
+    else:
+        tasks = [TaskNames.ocr_without_boxes]
+
+    img_pred = predictors["recognition"](
+        [img],
+        task_names=tasks,
+        bboxes=bboxes,
+        det_predictor=predictors["detection"],
+        highres_images=[highres_img],
+        math_mode=recognize_math,
+        return_words=True,
+    )[0]
+
+    bboxes = [line.bbox for line in img_pred.text_lines]
+    text = [line.text for line in img_pred.text_lines]
+    rec_img = draw_text_on_image(bboxes, text, img.size)
+
+    word_boxes = []
+    for line in img_pred.text_lines:
+        if line.words:
+            word_boxes.extend([word.bbox for word in line.words])
+
+    box_img = img.copy()
+    draw = ImageDraw.Draw(box_img)
+    for word_box in word_boxes:
+        draw.rectangle(word_box, outline="red", width=2)
+
+    return rec_img, img_pred, box_img
 
 def open_pdf(pdf_file):
     return pypdfium2.PdfDocument(pdf_file)
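
Note: the word-box overlay in the new ocr() depends on return_words=True, which attaches word-level boxes to each text line. As a standalone illustration of just the PIL drawing step, here is a minimal sketch; the coordinates are hypothetical stand-ins for the line.words[*].bbox values surya would return:

    # Minimal sketch of the word-box overlay step (PIL only, hypothetical boxes).
    from PIL import Image, ImageDraw

    img = Image.new("RGB", (400, 120), "white")
    word_boxes = [(10, 20, 90, 50), (100, 20, 210, 50), (10, 60, 150, 95)]

    box_img = img.copy()
    draw = ImageDraw.Draw(box_img)
    for word_box in word_boxes:
        # Same call the diff uses: a 2px red rectangle per word bbox.
        draw.rectangle(word_box, outline="red", width=2)
    box_img.save("word_boxes.png")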
@@ -188,13 +231,13 @@ with gr.Blocks(title="Surya") as demo:
             in_img = gr.Image(label="Select page of Image", type="pil", sources=None)
 
             text_det_btn = gr.Button("Run Text Detection")
-            inline_det_btn = gr.Button("Run Inline Math Detection")
             layout_det_btn = gr.Button("Run Layout Analysis")
 
-            lang_dropdown = gr.Dropdown(label="Languages", choices=sorted(list(CODE_TO_LANGUAGE.values())), multiselect=True, max_choices=4, info="Select the languages in the image (if known) to improve OCR accuracy. Optional.")
+            skip_text_detection_ckb = gr.Checkbox(label="Skip text detection", value=False, info="OCR only: Skip text detection and treat the whole image as a single line.")
+            recognize_math_ckb = gr.Checkbox(label="Recognize math in OCR", value=True, info="Enable math mode in OCR - this will recognize math.")
+            ocr_with_boxes_ckb = gr.Checkbox(label="OCR with boxes", value=True, info="Enable OCR with boxes - this will predict character-level boxes.")
             text_rec_btn = gr.Button("Run OCR")
 
-            use_pdf_boxes_ckb = gr.Checkbox(label="Use PDF table boxes", value=True, info="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
             skip_table_detection_ckb = gr.Checkbox(label="Skip table detection", value=False, info="Table recognition only: Skip table detection and treat the whole image/page as a table.")
             table_rec_btn = gr.Button("Run Table Rec")
 
@@ -202,6 +245,7 @@ with gr.Blocks(title="Surya") as demo:
         with gr.Column():
             result_img = gr.Image(label="Result image")
             result_json = gr.JSON(label="Result json")
+            ocr_boxes_img = gr.Image(label="OCR boxes image")
 
     def show_image(file, num=1):
         if file.endswith('.pdf'):
@@ -236,19 +280,6 @@ with gr.Blocks(title="Surya") as demo:
         inputs=[in_img],
         outputs=[result_img, result_json]
     )
-    def inline_det_img(pil_image):
-        det_img, text_pred, inline_pred = inline_detection(pil_image)
-        json = {
-            "text": text_pred.model_dump(exclude=["heatmap", "affinity_map"]),
-            "inline": inline_pred.model_dump(exclude=["heatmap", "affinity_map"])
-        }
-        return det_img, json
-    inline_det_btn.click(
-        fn=inline_det_img,
-        inputs=[in_img],
-        outputs=[result_img, result_json]
-    )
-
 
     # Run layout
     def layout_det_img(pil_image):
@@ -259,21 +290,29 @@ with gr.Blocks(title="Surya") as demo:
         inputs=[in_img],
         outputs=[result_img, result_json]
     )
+
     # Run OCR
-    def text_rec_img(pil_image, in_file, page_number, langs):
+    def text_rec_img(pil_image, in_file, page_number, skip_text_detection, recognize_math, ocr_with_boxes):
         if in_file.endswith('.pdf'):
             pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
         else:
             pil_image_highres = pil_image
-        rec_img, pred = ocr(pil_image, pil_image_highres, langs)
-        return rec_img, pred.model_dump()
+        rec_img, pred, box_img = ocr(
+            pil_image,
+            pil_image_highres,
+            skip_text_detection,
+            recognize_math,
+            with_bboxes=ocr_with_boxes,
+        )
+        return rec_img, pred.model_dump(), box_img
     text_rec_btn.click(
         fn=text_rec_img,
-        inputs=[in_img, in_file, in_num, lang_dropdown],
-        outputs=[result_img, result_json]
+        inputs=[in_img, in_file, in_num, skip_text_detection_ckb, recognize_math_ckb, ocr_with_boxes_ckb],
+        outputs=[result_img, result_json, ocr_boxes_img]
     )
+
     # Run Table Recognition
-    def table_rec_img(pil_image, in_file, page_number, use_pdf_boxes, skip_table_detection):
+    def table_rec_img(pil_image, in_file, page_number, skip_table_detection):
         if in_file.endswith('.pdf'):
             pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
         else:
@@ -282,20 +321,21 @@ with gr.Blocks(title="Surya") as demo:
         return table_img, [p.model_dump() for p in pred]
     table_rec_btn.click(
         fn=table_rec_img,
-        inputs=[in_img, in_file, in_num, use_pdf_boxes_ckb, skip_table_detection_ckb],
+        inputs=[in_img, in_file, in_num, skip_table_detection_ckb],
        outputs=[result_img, result_json]
     )
+
     # Run bad PDF text detection
-    def ocr_errors_pdf(file, num=1):
-        if file.endswith('.pdf'):
-            count = page_counter(file)
-        else:
+    def ocr_errors_pdf(in_file):
+        if not in_file.endswith('.pdf'):
             raise gr.Error("This feature only works with PDFs.", duration=5)
-        label, results = run_ocr_errors(io.BytesIO(open(file.name, "rb").read()), count)
+        page_count = page_counter(in_file)
+        io_file = io.BytesIO(open(in_file.name, "rb").read())
+        label, results = ocr_errors(io_file, page_count)
         return gr.update(label="Result json:" + label, value=results)
     ocr_errors_btn.click(
         fn=ocr_errors_pdf,
-        inputs=[in_file, in_num],
+        inputs=[in_file],
         outputs=[result_json]
     )
 
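
Note: the renamed ocr_errors (previously run_ocr_errors) samples text from a page window centered on the middle of the PDF and clamped to the document bounds. A quick standalone check of that arithmetic, with middle_page_range as a hypothetical helper name for the inline expression above:

    # Standalone check of the page-range clamping used in ocr_errors.
    def middle_page_range(page_count: int, max_pages: int = 15) -> range:
        page_middle = page_count // 2
        return range(
            max(page_middle - max_pages, 0), min(page_middle + max_pages, page_count)
        )

    assert list(middle_page_range(4)) == [0, 1, 2, 3]  # short PDF: all pages kept
    assert middle_page_range(100) == range(35, 65)     # long PDF: 30-page window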
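
Note: each gr.Checkbox listed in inputs reaches the click handler as a plain bool, in the order given, which is how the three new OCR options flow into text_rec_img. A self-contained sketch of that wiring pattern, with a stub in place of the real OCR call:

    # Hedged sketch of the checkbox-to-handler wiring the diff uses (stub handler).
    import gradio as gr

    def run_stub(skip_text_detection: bool, recognize_math: bool, ocr_with_boxes: bool):
        # Checkbox values arrive as bools, ordered like the `inputs` list below.
        return {
            "skip_text_detection": skip_text_detection,
            "recognize_math": recognize_math,
            "ocr_with_boxes": ocr_with_boxes,
        }

    with gr.Blocks() as demo:
        skip_ckb = gr.Checkbox(label="Skip text detection", value=False)
        math_ckb = gr.Checkbox(label="Recognize math in OCR", value=True)
        boxes_ckb = gr.Checkbox(label="OCR with boxes", value=True)
        btn = gr.Button("Run")
        out = gr.JSON(label="Result json")
        btn.click(fn=run_stub, inputs=[skip_ckb, math_ckb, boxes_ckb], outputs=[out])

    demo.launch()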