Add extract to excel
- app.py +110 -32
- requirements.txt +1 -0

app.py
CHANGED
@@ -4,24 +4,32 @@ import cv2
 import numpy as np
 import pandas as pd
 import torch
+import os
 import io
 # import sys
 # import json
 from collections import OrderedDict, defaultdict
 import xml.etree.ElementTree as ET
+from tempfile import TemporaryDirectory
+import xlsxwriter
 import matplotlib.pyplot as plt
 import matplotlib.patches as patches
 
 from paddleocr import PaddleOCR
-import pytesseract
-from pytesseract import Output
+# import pytesseract
+# from pytesseract import Output
 
 import postprocess
 
 
 ocr_instance = PaddleOCR(use_angle_cls=False, lang='en', use_gpu=True)
+
+torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
 detection_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/detection_wts.pt', force_reload=True)
+
+torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
 structure_model = torch.hub.load('ultralytics/yolov5', 'custom', 'weights/structure_wts.pt', force_reload=True)
+
 imgsz = 640
 
 detection_class_names = ['table', 'table rotated']
@@ -285,36 +293,36 @@ def visualize_cells(pil_img, cells):
     return PIL.Image.open(img_buf)
 
 
-def pytess(cell_pil_img):
-    return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip()
+# def pytess(cell_pil_img):
+#     return ' '.join(pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config='-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6')['text']).strip()
 
 
-def resize(pil_img, size=1800):
-    length_x, width_y = pil_img.size
-    factor = max(1, size / length_x)
-    size = int(factor * length_x), int(factor * width_y)
-    pil_img = pil_img.resize(size, PIL.Image.ANTIALIAS)
-    return pil_img, factor
+# def resize(pil_img, size=1800):
+#     length_x, width_y = pil_img.size
+#     factor = max(1, size / length_x)
+#     size = int(factor * length_x), int(factor * width_y)
+#     pil_img = pil_img.resize(size, PIL.Image.ANTIALIAS)
+#     return pil_img, factor
 
 
-def image_smoothening(img):
-    ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
-    ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-    blur = cv2.GaussianBlur(th2, (1, 1), 0)
-    ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-    return th3
+# def image_smoothening(img):
+#     ret1, th1 = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
+#     ret2, th2 = cv2.threshold(th1, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+#     blur = cv2.GaussianBlur(th2, (1, 1), 0)
+#     ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+#     return th3
 
 
-def remove_noise_and_smooth(pil_img):
-    img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
-    filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
-    kernel = np.ones((1, 1), np.uint8)
-    opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
-    closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
-    img = image_smoothening(img)
-    or_image = cv2.bitwise_or(img, closing)
-    pil_img = PIL.Image.fromarray(or_image)
-    return pil_img
+# def remove_noise_and_smooth(pil_img):
+#     img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
+#     filtered = cv2.adaptiveThreshold(img.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3)
+#     kernel = np.ones((1, 1), np.uint8)
+#     opening = cv2.morphologyEx(filtered, cv2.MORPH_OPEN, kernel)
+#     closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
+#     img = image_smoothening(img)
+#     or_image = cv2.bitwise_or(img, closing)
+#     pil_img = PIL.Image.fromarray(or_image)
+#     return pil_img
 
 
 # def extract_text_from_cells(pil_img, cells):
@@ -438,6 +446,53 @@ def cells_to_html(cells):
 #     return html_code
 
 
+def cells_to_excel(cells, file_path):
+
+    def int2xlsx(i):
+        if i < 26:
+            return chr(i + 65)
+        return f'{chr(i // 26 + 64)}{chr(i % 26 + 65)}'
+
+    cells = sorted(cells, key=lambda k: min(k['column_nums']))
+    cells = sorted(cells, key=lambda k: min(k['row_nums']))
+
+    workbook = xlsxwriter.Workbook(file_path)
+
+    cell_format = workbook.add_format(
+        {
+            'align': 'center',
+            'valign': 'vcenter',
+        }
+    )
+
+    worksheet = workbook.add_worksheet(name='Table')
+
+    table_start_index = 0
+
+    for cell in cells:
+        start_row = min(cell['row_nums'])
+        end_row = max(cell['row_nums'])
+        start_col = min(cell['column_nums'])
+        end_col = max(cell['column_nums'])
+        if start_row == end_row and start_col == end_col:
+            worksheet.write(
+                table_start_index + start_row,
+                start_col,
+                cell['cell_text'],
+                cell_format,
+            )
+        else:
+            if start_col == end_col and start_row == end_row:
+                excel_index = f'{int2xlsx(table_start_index + start_col)}{table_start_index + start_row + 1}'
+            else:
+                excel_index = f'{int2xlsx(table_start_index + start_col)}{table_start_index + start_row + 1}:{int2xlsx(table_start_index + end_col)}{table_start_index + end_row + 1}'
+            worksheet.merge_range(
+                excel_index, cell['cell_text'], cell_format
+            )
+
+    workbook.close()
+
+
 def main():
 
     st.set_page_config(layout='wide')
@@ -453,7 +508,7 @@ def main():
 
     else:
        tabs = st.tabs(
-           ['Table Detection', 'Table Structure Recognition']
+           ['Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
        )
 
        print(filename)
@@ -462,24 +517,29 @@
        detection_result = table_detection(pil_img)
        crop_images, vis_det_img = crop_image(pil_img, detection_result)
 
+       all_cells = []
+
        with tabs[0]:
+           st.header('Table Detection')
            st.image(vis_det_img)
 
        with tabs[1]:
-           str_cols = st.columns((len(crop_images), ) * 5)
+           st.header('Table Structure Recognition')
+
+           str_cols = st.columns((len(crop_images), ) * 4)
            str_cols[0].subheader('Table image')
            str_cols[1].subheader('OCR result')
            str_cols[2].subheader('Structure result')
            str_cols[3].subheader('Cells result')
-           str_cols[4].subheader('CSV result')
 
            for i, img in enumerate(crop_images):
                ocr_result = ocr(img)
                structure_result = table_structure(img)
                table_structures, cells, confidence_score = convert_stucture(ocr_result, img, structure_result)
                cells = extract_text_from_cells(cells)
+               all_cells.append(cells)
                html_result = cells_to_html(cells)
-               df, csv_result = cells_to_csv(cells)
+               #df, csv_result = cells_to_csv(cells)
                #print(df)
 
                vis_ocr_img = visualize_ocr(img, ocr_result)
@@ -490,12 +550,30 @@ def main():
                str_cols[1].image(vis_ocr_img)
                str_cols[2].image(vis_str_img)
                str_cols[3].image(vis_cells_img)
-               #str_cols[4].dataframe(df)
-               str_cols[4].download_button('Download table', csv_result, f'table-{i}.csv', 'text/csv', key=f'download-csv-{i}')
 
                st.write('\n')
                st.markdown(html_result, unsafe_allow_html=True)
 
+       with tabs[2]:
+           st.header('Extracted Table(s)')
+           for idx, col in enumerate(st.columns(len(all_cells))):
+               with col:
+                   if len(all_cells) > 1:
+                       st.header(f'Table {idx + 1}')
+
+                   with TemporaryDirectory() as temp_dir_path:
+                       df = None
+                       xlsx_path = os.path.join(temp_dir_path, f'debug_{idx}.xlsx')
+                       cells_to_excel(all_cells[idx], xlsx_path)
+                       with open(xlsx_path, 'rb') as ref:
+                           df = pd.read_excel(ref)
+                           st.dataframe(df)
+                           st.download_button(
+                               'Download Excel File',
+                               ref,
+                               file_name=f'output_{idx}.xlsx',
+                           )
+
 
 if __name__ == '__main__':
     main()
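For reference, a minimal standalone sketch of the merge logic the new cells_to_excel helper implements. The toy cell dicts and the toy.xlsx output path are hypothetical; only the row_nums / column_nums / cell_text keys mirror what extract_text_from_cells produces in app.py.

# Standalone sketch, not part of the commit.
import xlsxwriter

def int2xlsx(i):
    # Zero-based column index -> Excel letters: 0 -> 'A', 25 -> 'Z', 26 -> 'AA'.
    # Two-letter names are as far as this helper goes, enough for typical tables.
    if i < 26:
        return chr(i + 65)
    return f'{chr(i // 26 + 64)}{chr(i % 26 + 65)}'

# Hypothetical cells: a header spanning two columns, then two body cells.
cells = [
    {'row_nums': [0], 'column_nums': [0, 1], 'cell_text': 'Spanning header'},
    {'row_nums': [1], 'column_nums': [0], 'cell_text': 'left'},
    {'row_nums': [1], 'column_nums': [1], 'cell_text': 'right'},
]

workbook = xlsxwriter.Workbook('toy.xlsx')
worksheet = workbook.add_worksheet(name='Table')
for cell in cells:
    start_row, end_row = min(cell['row_nums']), max(cell['row_nums'])
    start_col, end_col = min(cell['column_nums']), max(cell['column_nums'])
    if start_row == end_row and start_col == end_col:
        # Cells confined to one grid position are plain writes.
        worksheet.write(start_row, start_col, cell['cell_text'])
    else:
        # Cells spanning rows/columns become merged ranges, e.g. 'A1:B1'.
        ref = f'{int2xlsx(start_col)}{start_row + 1}:{int2xlsx(end_col)}{end_row + 1}'
        worksheet.merge_range(ref, cell['cell_text'])
workbook.close()
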
requirements.txt
CHANGED
@@ -76,3 +76,4 @@ setuptools>=65.5.1 # Snyk vulnerability fix
 # Other
 pytesseract==0.3.10
 # beautifulsoup4==4.11.1
+xlsxwriter
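A quick, hypothetical sanity check (not part of the commit) that the new dependency resolves in the Space's environment:

# Hypothetical check: confirm xlsxwriter imports and which version was resolved,
# since the new requirements entry is unpinned (unlike pytesseract==0.3.10 above).
import xlsxwriter
print(xlsxwriter.__version__)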