Spaces:
Paused
Paused
Commit
·
ca5b08e
0
Parent(s):
Initial commit for HF Space (no images)
Browse files- .dockerignore +5 -0
- .github/workflows/docker.yml +46 -0
- .gitignore +194 -0
- Dockerfile +43 -0
- LICENSE +201 -0
- README.md +397 -0
- eval/eval.sh +27 -0
- eval/eval_element_merge_detect.py +137 -0
- eval/eval_html_table_merge.py +208 -0
- eval/eval_page_to_markdown.py +76 -0
- eval/eval_page_to_markdown_nanonets.py +160 -0
- eval/eval_page_to_markdown_olmocr.py +157 -0
- eval/eval_table_to_html.py +206 -0
- eval/eval_table_to_html_nanonets.py +295 -0
- eval/eval_table_to_html_olmocr.py +212 -0
- eval/gen_element_merge_detect_data.py +36 -0
- eval/gen_html_table_merge_data.py +32 -0
- eval/parallel.py +50 -0
- ocrflux/check.py +44 -0
- ocrflux/image_utils.py +50 -0
- ocrflux/inference.py +237 -0
- ocrflux/jsonl_to_markdown.py +37 -0
- ocrflux/metrics.py +147 -0
- ocrflux/pipeline.py +861 -0
- ocrflux/prompts.py +60 -0
- ocrflux/table_format.py +143 -0
- ocrflux/work_queue.py +357 -0
- pyproject.toml +75 -0
.dockerignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
**/*
|
3 |
+
.*
|
4 |
+
!ocrflux
|
5 |
+
!pyproject.toml
|
.github/workflows/docker.yml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
on:
|
2 |
+
push:
|
3 |
+
branches:
|
4 |
+
- main
|
5 |
+
- 'v*.*.*'
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
build_and_push_docker:
|
9 |
+
runs-on: ubuntu-latest
|
10 |
+
|
11 |
+
steps:
|
12 |
+
- name: Checkout repository
|
13 |
+
uses: actions/checkout@v4
|
14 |
+
|
15 |
+
- name: Set up Docker Buildx
|
16 |
+
uses: docker/setup-buildx-action@v3
|
17 |
+
|
18 |
+
- name: Log in to Docker Hub
|
19 |
+
uses: docker/login-action@v3
|
20 |
+
with:
|
21 |
+
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
22 |
+
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
23 |
+
|
24 |
+
- name: Determine image tags
|
25 |
+
id: determine_tags
|
26 |
+
run: |
|
27 |
+
BRANCH_NAME=${{ github.ref_name }}
|
28 |
+
DOCKER_IMAGE_NAME="chatdoc/ocrflux"
|
29 |
+
|
30 |
+
if [[ "$BRANCH_NAME" == "main" ]]; then
|
31 |
+
echo "IMAGE_TAGS=$DOCKER_IMAGE_NAME:latest,$DOCKER_IMAGE_NAME:$BRANCH_NAME"
|
32 |
+
echo "image_tags=$DOCKER_IMAGE_NAME:latest,$DOCKER_IMAGE_NAME:$BRANCH_NAME" >> $GITHUB_OUTPUT
|
33 |
+
else
|
34 |
+
echo "IMAGE_TAGS=$DOCKER_IMAGE_NAME:$BRANCH_NAME"
|
35 |
+
echo "image_tags=$DOCKER_IMAGE_NAME:$BRANCH_NAME" >> $GITHUB_OUTPUT
|
36 |
+
fi
|
37 |
+
|
38 |
+
- name: Build and push Docker image
|
39 |
+
id: docker_build
|
40 |
+
uses: docker/build-push-action@v6
|
41 |
+
with:
|
42 |
+
context: .
|
43 |
+
push: true
|
44 |
+
tags: ${{ steps.determine_tags.outputs.image_tags }}
|
45 |
+
cache-from: type=gha,scope=${{ github.workflow }}
|
46 |
+
cache-to: type=gha,scope=${{ github.workflow }},mode=max
|
.gitignore
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# UV
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
#uv.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
#.idea/
|
169 |
+
|
170 |
+
# Abstra
|
171 |
+
# Abstra is an AI-powered process automation framework.
|
172 |
+
# Ignore directories containing user credentials, local state, and settings.
|
173 |
+
# Learn more at https://abstra.io/docs
|
174 |
+
.abstra/
|
175 |
+
|
176 |
+
# Visual Studio Code
|
177 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
178 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
179 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
180 |
+
# you could uncomment the following to ignore the enitre vscode folder
|
181 |
+
# .vscode/
|
182 |
+
|
183 |
+
# Ruff stuff:
|
184 |
+
.ruff_cache/
|
185 |
+
|
186 |
+
# PyPI configuration file
|
187 |
+
.pypirc
|
188 |
+
|
189 |
+
# Cursor
|
190 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
191 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
192 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
193 |
+
.cursorignore
|
194 |
+
.cursorindexingignore
|
Dockerfile
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM ubuntu:24.04
|
2 |
+
|
3 |
+
WORKDIR /OCRFlux
|
4 |
+
|
5 |
+
ENV LANG=en_US.UTF-8 \
|
6 |
+
PIP_ROOT_USER_ACTION=ignore \
|
7 |
+
PIP_BREAK_SYSTEM_PACKAGES=true \
|
8 |
+
PIP_NO_CACHE_DIR=true \
|
9 |
+
PIP_DISABLE_PIP_VERSION_CHECK=true \
|
10 |
+
PYTHONPATH=/OCRFlux
|
11 |
+
|
12 |
+
SHELL ["/bin/bash", "-c"]
|
13 |
+
|
14 |
+
RUN --mount=type=bind,source=./,target=/builder \
|
15 |
+
cp -a /builder/. /OCRFlux/ && \
|
16 |
+
set -o pipefail && \
|
17 |
+
apt-get update && \
|
18 |
+
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
|
19 |
+
ca-certificates \
|
20 |
+
curl \
|
21 |
+
fonts-crosextra-caladea \
|
22 |
+
fonts-crosextra-carlito \
|
23 |
+
gsfonts \
|
24 |
+
lcdf-typetools \
|
25 |
+
locales \
|
26 |
+
msttcorefonts \
|
27 |
+
poppler-utils \
|
28 |
+
poppler-data \
|
29 |
+
python3.12-dev \
|
30 |
+
python3.12-full \
|
31 |
+
software-properties-common \
|
32 |
+
ttf-mscorefonts-installer && \
|
33 |
+
locale-gen en_US.UTF-8 && \
|
34 |
+
curl https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py && \
|
35 |
+
python3.12 /tmp/get-pip.py && \
|
36 |
+
python3.12 -m pip install . --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ && \
|
37 |
+
rm -rf ./* \
|
38 |
+
/var/lib/apt/lists/* \
|
39 |
+
/tmp/* \
|
40 |
+
/root/.cache/pip &&\
|
41 |
+
find /var/log /var/cache -type f -delete
|
42 |
+
|
43 |
+
ENTRYPOINT ["python3.12", "-m", "ocrflux.pipeline"]
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<img src="./images/OCRFlux.png" alt="OCRFlux Logo" width="300"/>
|
3 |
+
<hr/>
|
4 |
+
</div>
|
5 |
+
<p align="center">
|
6 |
+
<a href="https://github.com/chatdoc-com/OCRFlux/blob/main/LICENSE">
|
7 |
+
<img alt="GitHub License" src="./images/license.svg" height="20">
|
8 |
+
</a>
|
9 |
+
<a href="https://github.com/chatdoc-com/OCRFlux/releases">
|
10 |
+
<img alt="GitHub release" src="./images/release.svg" height="20">
|
11 |
+
</a>
|
12 |
+
<a href="https://ocrflux.pdfparser.io/">
|
13 |
+
<img alt="Demo" src="./images/demo.svg" height="20">
|
14 |
+
</a>
|
15 |
+
<a href="https://discord.gg/F33mhsAqqg">
|
16 |
+
<img alt="Discord" src="./images/discord.svg" height="20">
|
17 |
+
</a>
|
18 |
+
</p>
|
19 |
+
|
20 |
+
OCRFlux is a multimodal large language model based toolkit for converting PDFs and images into clean, readable, plain Markdown text. It aims to push the current state-of-the-art to a significantly higher level.
|
21 |
+
|
22 |
+
Try the online demo: [OCRFlux Demo](https://ocrflux.pdfparser.io/)
|
23 |
+
|
24 |
+
Functions: **Whole file parsing**
|
25 |
+
- On each page
|
26 |
+
- Convert into text with a natural reading order, even in the presence of multi-column layouts, figures, and insets
|
27 |
+
- Support for complicated tables and equations
|
28 |
+
- Automatically removes headers and footers
|
29 |
+
|
30 |
+
- Cross-page table/paragraph merging
|
31 |
+
- Cross-page table merging
|
32 |
+
- Cross-page paragraph merging
|
33 |
+
|
34 |
+
|
35 |
+
Key features:
|
36 |
+
- Superior parsing quality on each page
|
37 |
+
|
38 |
+
It respectively achieves 0.095 higher (from 0.872 to 0.967), 0.109 higher (from 0.858 to 0.967) and 0.187 higher (from 0.780 to 0.967) Edit Distance Similarity (EDS) on our released benchmark [OCRFlux-bench-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-single) than the baseline model [olmOCR-7B-0225-preview](https://huggingface.co/allenai/olmOCR-7B-0225-preview), [Nanonets-OCR-s](https://huggingface.co/nanonets/Nanonets-OCR-s) and [MonkeyOCR](https://huggingface.co/echo840/MonkeyOCR).
|
39 |
+
|
40 |
+
- Native support for cross-page table/paragraph merging (to our best this is the first to support this feature in all the open sourced project).
|
41 |
+
|
42 |
+
- Based on a 3B parameter VLM, so it can run even on GTX 3090 GPU.
|
43 |
+
|
44 |
+
Release:
|
45 |
+
- [OCRFlux-3B](https://huggingface.co/ChatDOC/OCRFlux-3B) - 3B parameter VLM
|
46 |
+
- Benchmark for evaluation
|
47 |
+
- [OCRFlux-bench-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-single)
|
48 |
+
- [OCRFlux-pubtabnet-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-single)
|
49 |
+
- [OCRFlux-bench-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-cross)
|
50 |
+
- [OCRFlux-pubtabnet-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-cross)
|
51 |
+
|
52 |
+
|
53 |
+
### News
|
54 |
+
- Jun 17, 2025 - v0.1.0 - Initial public launch and demo.
|
55 |
+
|
56 |
+
### Benchmark for single-page parsing
|
57 |
+
|
58 |
+
We ship two comprehensive benchmarks to help measure the performance of our OCR system in single-page parsing:
|
59 |
+
|
60 |
+
- [OCRFlux-bench-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-single): Containing 2000 pdf pages (1000 English pages and 1000 Chinese pages) and their ground-truth Markdowns (manually labeled with multi-round check).
|
61 |
+
|
62 |
+
- [OCRFlux-pubtabnet-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-single): Derived from the public [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet) benchmark with some format transformation. It contains 9064 HTML table samples, which are split into simple tables and complex tables according to whether they have rowspan and colspan cells.
|
63 |
+
|
64 |
+
We emphasize that the released benchmarks are NOT included in our training and evaluation data. The following is the main result:
|
65 |
+
|
66 |
+
|
67 |
+
1. In [OCRFlux-bench-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-single), we calculated the Edit Distance Similarity (EDS) between the generated Markdowns and the ground-truth Markdowns as the metric.
|
68 |
+
|
69 |
+
<table>
|
70 |
+
<thead>
|
71 |
+
<tr>
|
72 |
+
<th>Language</th>
|
73 |
+
<th>Model</th>
|
74 |
+
<th>Avg EDS ↑</th>
|
75 |
+
</tr>
|
76 |
+
</thead>
|
77 |
+
<tbody>
|
78 |
+
<tr>
|
79 |
+
<td rowspan="4">English</td>
|
80 |
+
<td>olmOCR-7B-0225-preview</td>
|
81 |
+
<td>0.885</td>
|
82 |
+
</tr>
|
83 |
+
<tr>
|
84 |
+
<td>Nanonets-OCR-s</td>
|
85 |
+
<td>0.870</td>
|
86 |
+
</tr>
|
87 |
+
<tr>
|
88 |
+
<td>MonkeyOCR</td>
|
89 |
+
<td>0.828</td>
|
90 |
+
</tr>
|
91 |
+
<tr>
|
92 |
+
<td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
|
93 |
+
<td>0.971</td>
|
94 |
+
</tr>
|
95 |
+
<tr>
|
96 |
+
<td rowspan="4">Chinese</td>
|
97 |
+
<td>olmOCR-7B-0225-preview</td>
|
98 |
+
<td>0.859</td>
|
99 |
+
</tr>
|
100 |
+
<tr>
|
101 |
+
<td>Nanonets-OCR-s</td>
|
102 |
+
<td>0.846</td>
|
103 |
+
</tr>
|
104 |
+
<tr>
|
105 |
+
<td>MonkeyOCR</td>
|
106 |
+
<td>0.731</td>
|
107 |
+
</tr>
|
108 |
+
<tr>
|
109 |
+
<td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
|
110 |
+
<td>0.962</td>
|
111 |
+
</tr>
|
112 |
+
<tr>
|
113 |
+
<td rowspan="4">Total</td>
|
114 |
+
<td>olmOCR-7B-0225-preview</td>
|
115 |
+
<td>0.872</td>
|
116 |
+
</tr>
|
117 |
+
<tr>
|
118 |
+
<td>Nanonets-OCR-s</td>
|
119 |
+
<td>0.858</td>
|
120 |
+
</tr>
|
121 |
+
<tr>
|
122 |
+
<td>MonkeyOCR</td>
|
123 |
+
<td>0.780</td>
|
124 |
+
</tr>
|
125 |
+
<tr>
|
126 |
+
<td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
|
127 |
+
<td>0.967</td>
|
128 |
+
</tr>
|
129 |
+
</tbody>
|
130 |
+
</table>
|
131 |
+
|
132 |
+
2. In [OCRFlux-pubtabnet-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-single), we calculated the Tree Edit Distance-based Similarity (TEDS) between the generated HTML tables and the ground-truth HTML tables as the metric.
|
133 |
+
<table>
|
134 |
+
<thead>
|
135 |
+
<tr>
|
136 |
+
<th>Type</th>
|
137 |
+
<th>Model</th>
|
138 |
+
<th>Avg TEDS ↑</th>
|
139 |
+
</tr>
|
140 |
+
</thead>
|
141 |
+
<tbody>
|
142 |
+
<tr>
|
143 |
+
<td rowspan="4">Simple</td>
|
144 |
+
<td>olmOCR-7B-0225-preview</td>
|
145 |
+
<td>0.810</td>
|
146 |
+
</tr>
|
147 |
+
<tr>
|
148 |
+
<td>Nanonets-OCR-s</td>
|
149 |
+
<td>0.882</td>
|
150 |
+
</tr>
|
151 |
+
<tr>
|
152 |
+
<td>MonkeyOCR</td>
|
153 |
+
<td>0.880</td>
|
154 |
+
</tr>
|
155 |
+
<tr>
|
156 |
+
<td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
|
157 |
+
<td>0.912</td>
|
158 |
+
</tr>
|
159 |
+
<tr>
|
160 |
+
<td rowspan="4">Complex</td>
|
161 |
+
<td>olmOCR-7B-0225-preview</td>
|
162 |
+
<td>0.676</td>
|
163 |
+
</tr>
|
164 |
+
<tr>
|
165 |
+
<td>Nanonets-OCR-s</td>
|
166 |
+
<td>0.772</td>
|
167 |
+
</tr>
|
168 |
+
<tr>
|
169 |
+
<td><strong>MonkeyOCR<strong></td>
|
170 |
+
<td>0.826</td>
|
171 |
+
</tr>
|
172 |
+
<tr>
|
173 |
+
<td><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></td>
|
174 |
+
<td>0.807</td>
|
175 |
+
</tr>
|
176 |
+
<tr>
|
177 |
+
<td rowspan="4">Total</td>
|
178 |
+
<td>olmOCR-7B-0225-preview</td>
|
179 |
+
<td>0.744</td>
|
180 |
+
</tr>
|
181 |
+
<tr>
|
182 |
+
<td>Nanonets-OCR-s</td>
|
183 |
+
<td>0.828</td>
|
184 |
+
</tr>
|
185 |
+
<tr>
|
186 |
+
<td>MonkeyOCR</td>
|
187 |
+
<td>0.853</td>
|
188 |
+
</tr>
|
189 |
+
<tr>
|
190 |
+
<td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
|
191 |
+
<td>0.861</td>
|
192 |
+
</tr>
|
193 |
+
</tbody>
|
194 |
+
</table>
|
195 |
+
|
196 |
+
We also conduct some case studies to show the superiority of our model in the [blog](https://ocrflux.pdfparser.io/#/blog) article.
|
197 |
+
|
198 |
+
### Benchmark for cross-page table/paragraph merging
|
199 |
+
|
200 |
+
PDF documents are typically paginated, which often results in tables or paragraphs being split across consecutive pages. Accurately detecting and merging such cross-page structures is crucial to avoid generating incomplete or fragmented content.
|
201 |
+
|
202 |
+
The detection task can be formulated as follows: given the Markdowns of two consecutive pages—each structured as a list of Markdown elements (e.g., paragraphs and tables)—the goal is to identify the indexes of elements that should be merged across the pages.
|
203 |
+
|
204 |
+
Then for the merging task, if the elements to be merged are paragraphs, we can just concate them. However, for two table fragments, their merging is much more challenging. For example, the table spanning multiple pages will repeat the header of the first page on the second page. Another difficult scenario is that the table cell contains long content that spans multiple lines within the cell, with the first few lines appearing on the previous page and the remaining lines continuing on the next page. We also observe some cases where tables with a large number of columns are split vertically and placed on two consecutive pages. More examples of cross-page tables can be found in our [blog](https://ocrflux.pdfparser.io/#/blog) article. To address these issues, we develop the LLM model for cross-page table merging. Specifically, this model takes two split table fragments as input and generates a complete, well-structured table as output.
|
205 |
+
|
206 |
+
We ship two comprehensive benchmarks to help measure the performance of our OCR system in cross-page table/paragraph detection and merging tasks respectively:
|
207 |
+
|
208 |
+
- [OCRFlux-bench-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-cross): Containing 1000 samples (500 English samples and 500 Chinese samples), each sample contains the Markdown element lists of two consecutive pages, along with the indexes of elements that need to be merged (manually labeled through multiple rounds of review). If no tables or paragraphs require merging, the indexes in the annotation data are left empty.
|
209 |
+
|
210 |
+
- [OCRFlux-pubtabnet-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-cross): Containing 9064 pairs of split table fragments, along with their corresponding ground-truth merged versions.
|
211 |
+
|
212 |
+
The released benchmarks are NOT included in our training and evaluation data neither. The following is the main result:
|
213 |
+
|
214 |
+
1. In [OCRFlux-bench-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-cross), we caculated the Accuracy, Precision, Recall and F1 score as the metric. Notice that the detection results are right only when it accurately judges whether there are elements that need to be merged across the two pages and output the right indexes of them.
|
215 |
+
|
216 |
+
| Language | Precision ↑ | Recall ↑ | F1 ↑ | Accuracy ↑ |
|
217 |
+
|----------|-------------|----------|-------|------------|
|
218 |
+
| English | 0.992 | 0.964 | 0.978 | 0.978 |
|
219 |
+
| Chinese | 1.000 | 0.988 | 0.994 | 0.994 |
|
220 |
+
| Total | 0.996 | 0.976 | 0.986 | 0.986 |
|
221 |
+
|
222 |
+
2. In [OCRFlux-pubtabnet-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-cross), we calculate the Tree Edit Distance-based Similarity (TEDS) between the generated merged table and the ground-truth merged table as the metric.
|
223 |
+
|
224 |
+
| Table type | Avg TEDS ↑ |
|
225 |
+
|------------|--------------|
|
226 |
+
| Simple | 0.965 |
|
227 |
+
| Complex | 0.935 |
|
228 |
+
| Total | 0.950 |
|
229 |
+
|
230 |
+
### Installation
|
231 |
+
|
232 |
+
Requirements:
|
233 |
+
- Recent NVIDIA GPU (tested on RTX 3090, 4090, L40S, A100, H100) with at least 12 GB of GPU RAM
|
234 |
+
- 20GB of free disk space
|
235 |
+
|
236 |
+
You will need to install poppler-utils and additional fonts for rendering PDF images.
|
237 |
+
|
238 |
+
Install dependencies (Ubuntu/Debian)
|
239 |
+
```bash
|
240 |
+
sudo apt-get update
|
241 |
+
sudo apt-get install poppler-utils poppler-data ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
|
242 |
+
```
|
243 |
+
|
244 |
+
Set up a conda environment and install OCRFlux. The requirements for running OCRFlux
|
245 |
+
are difficult to install in an existing python environment, so please do make a clean python environment to install into.
|
246 |
+
```bash
|
247 |
+
conda create -n ocrflux python=3.11
|
248 |
+
conda activate ocrflux
|
249 |
+
|
250 |
+
git clone https://github.com/chatdoc-com/OCRFlux.git
|
251 |
+
cd ocrflux
|
252 |
+
|
253 |
+
pip install -e . --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
|
254 |
+
```
|
255 |
+
|
256 |
+
### Local Usage Example
|
257 |
+
|
258 |
+
For quick testing, try the [web demo](https://5f65ccdc2d4fd2f364.gradio.live). To run locally, a GPU is required, as inference is powered by [vllm](hhttps://github.com/vllm-project/vllm) under the hood.
|
259 |
+
|
260 |
+
- For a pdf document:
|
261 |
+
```bash
|
262 |
+
python -m ocrflux.pipeline ./localworkspace --data test.pdf --model /model_dir/OCRFlux-3B
|
263 |
+
```
|
264 |
+
|
265 |
+
- For an image:
|
266 |
+
```bash
|
267 |
+
python -m ocrflux.pipeline ./localworkspace --data test_page.png --model /model_dir/OCRFlux-3B
|
268 |
+
```
|
269 |
+
|
270 |
+
- For a directory of pdf or images:
|
271 |
+
```bash
|
272 |
+
python -m ocrflux.pipeline ./localworkspace --data test_pdf_dir/* --model /model_dir/OCRFlux-3B
|
273 |
+
```
|
274 |
+
You can set `--skip_cross_page_merge` to skip the cross-page merging in the parsing process to accelerate, it would simply concatenate the parsing results of each page to generate final Markdown of the document.
|
275 |
+
|
276 |
+
Results will be stored as JSONL files in the `./localworkspace/results` directory.
|
277 |
+
|
278 |
+
Each line in JSONL files is a json object with the following fields:
|
279 |
+
|
280 |
+
```
|
281 |
+
{
|
282 |
+
"orig_path": str, # the path to the raw pdf or image file
|
283 |
+
"num_pages": int, # the number of pages in the pdf file
|
284 |
+
"document_text": str, # the Markdown text of the converted pdf or image file
|
285 |
+
"page_texts": dict, # the Markdown texts of each page in the pdf file, the key is the page index and the value is the Markdown text of the page
|
286 |
+
"fallback_pages": [int], # the page indexes that are not converted successfully
|
287 |
+
}
|
288 |
+
```
|
289 |
+
|
290 |
+
### API for directly calling OCRFlux (New)
|
291 |
+
You can use the inference API to directly call OCRFlux in your codes without using an online vllm server like following:
|
292 |
+
|
293 |
+
```
|
294 |
+
from vllm import LLM
|
295 |
+
from ocrflux.inference import parse
|
296 |
+
|
297 |
+
file_path = 'test.pdf'
|
298 |
+
# file_path = 'test.png'
|
299 |
+
llm = LLM(model="model_dir/OCRFlux-3B",gpu_memory_utilization=0.8,max_model_len=8192)
|
300 |
+
result = parse(llm,file_path)
|
301 |
+
if result != None:
|
302 |
+
document_markdown = result['document_text']
|
303 |
+
print(document_markdown)
|
304 |
+
with open('test.md','w') as f:
|
305 |
+
f.write(document_markdown)
|
306 |
+
else:
|
307 |
+
print("Parse failed.")
|
308 |
+
```
|
309 |
+
If parsing is failed or there are fallback pages in the result, you can try to set the argument `max_page_retries` for the `parse` function with a positive integer to get a better result. But it may cause longer inference time.
|
310 |
+
|
311 |
+
### Docker Usage
|
312 |
+
|
313 |
+
Requirements:
|
314 |
+
|
315 |
+
- Docker with GPU support [(NVIDIA Toolkit)](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
|
316 |
+
- Pre-downloaded model: [OCRFlux-3B](https://huggingface.co/ChatDOC/OCRFlux-3B)
|
317 |
+
|
318 |
+
To use OCRFlux in a docker container, you can use the following example command:
|
319 |
+
|
320 |
+
```bash
|
321 |
+
docker run -it --gpus all \
|
322 |
+
-v /path/to/localworkspace:/localworkspace \
|
323 |
+
-v /path/to/test_pdf_dir:/test_pdf_dir/ \
|
324 |
+
-v /path/to/OCRFlux-3B:/OCRFlux-3B \
|
325 |
+
chatdoc/ocrflux:latest /localworkspace --data /test_pdf_dir/* --model /OCRFlux-3B/
|
326 |
+
```
|
327 |
+
|
328 |
+
#### Viewing Results
|
329 |
+
Generate the final Markdown files by running the following command. Generated Markdown files will be in `./localworkspace/markdowns/DOCUMENT_NAME` directory.
|
330 |
+
|
331 |
+
```bash
|
332 |
+
python -m ocrflux.jsonl_to_markdown ./localworkspace
|
333 |
+
```
|
334 |
+
|
335 |
+
### Full documentation for the pipeline
|
336 |
+
|
337 |
+
```bash
|
338 |
+
python -m ocrflux.pipeline --help
|
339 |
+
usage: pipeline.py [-h] [--task {pdf2markdown,merge_pages,merge_tables}] [--data [DATA ...]] [--pages_per_group PAGES_PER_GROUP] [--max_page_retries MAX_PAGE_RETRIES]
|
340 |
+
[--max_page_error_rate MAX_PAGE_ERROR_RATE] [--workers WORKERS] [--model MODEL] [--model_max_context MODEL_MAX_CONTEXT] [--model_chat_template MODEL_CHAT_TEMPLATE]
|
341 |
+
[--target_longest_image_dim TARGET_LONGEST_IMAGE_DIM] [--skip_cross_page_merge] [--port PORT]
|
342 |
+
workspace
|
343 |
+
|
344 |
+
Manager for running millions of PDFs through a batch inference pipeline
|
345 |
+
|
346 |
+
positional arguments:
|
347 |
+
workspace The filesystem path where work will be stored, can be a local folder
|
348 |
+
|
349 |
+
options:
|
350 |
+
-h, --help show this help message and exit
|
351 |
+
--data [DATA ...] List of paths to files to process
|
352 |
+
--pages_per_group PAGES_PER_GROUP
|
353 |
+
Aiming for this many pdf pages per work item group
|
354 |
+
--max_page_retries MAX_PAGE_RETRIES
|
355 |
+
Max number of times we will retry rendering a page
|
356 |
+
--max_page_error_rate MAX_PAGE_ERROR_RATE
|
357 |
+
Rate of allowable failed pages in a document, 1/250 by default
|
358 |
+
--workers WORKERS Number of workers to run at a time
|
359 |
+
--model MODEL The path to the model
|
360 |
+
--model_max_context MODEL_MAX_CONTEXT
|
361 |
+
Maximum context length that the model was fine tuned under
|
362 |
+
--model_chat_template MODEL_CHAT_TEMPLATE
|
363 |
+
Chat template to pass to vllm server
|
364 |
+
--target_longest_image_dim TARGET_LONGEST_IMAGE_DIM
|
365 |
+
Dimension on longest side to use for rendering the pdf pages
|
366 |
+
--skip_cross_page_merge
|
367 |
+
Whether to skip cross-page merging
|
368 |
+
--port PORT Port to use for the VLLM server
|
369 |
+
```
|
370 |
+
|
371 |
+
## Code overview
|
372 |
+
|
373 |
+
There are some nice reusable pieces of the code that may be useful for your own projects:
|
374 |
+
- Processing millions of PDFs through our released model using VLLM - [pipeline.py](https://github.com/chatdoc-com/OCRFlux/blob/main/ocrflux/pipeline.py)
|
375 |
+
- Generating final Markdowns from jsonl files - [jsonl_to_markdown.py](https://github.com/chatdoc-com/OCRFlux/blob/main/ocrflux/jsonl_to_markdown.py)
|
376 |
+
- Evaluating the model on the single-page parsing task - [eval_page_to_markdown.py](https://github.com/chatdoc-com/OCRFlux/blob/main/eval/eval_page_to_markdown.py)
|
377 |
+
- Evaluating the model on the table parising task - [eval_table_to_html.py](https://github.com/chatdoc-com/OCRFlux/blob/main/eval/eval_table_to_html.py)
|
378 |
+
- Evaluating the model on the paragraphs/tables merging detection task - [eval_element_merge_detect.py](https://github.com/chatdoc-com/OCRFlux/blob/main/eval/eval_element_merge_detect.py)
|
379 |
+
- Evaluating the model on the table merging task - [eval_html_table_merge.py](https://github.com/chatdoc-com/OCRFlux/blob/main/eval/eval_html_table_merge.py)
|
380 |
+
|
381 |
+
|
382 |
+
## Team
|
383 |
+
|
384 |
+
<!-- start team -->
|
385 |
+
|
386 |
+
**OCRFlux** is developed and maintained by the ChatDOC team, backed by [ChatDOC](https://chatdoc.com/).
|
387 |
+
|
388 |
+
<!-- end team -->
|
389 |
+
|
390 |
+
## License
|
391 |
+
|
392 |
+
<!-- start license -->
|
393 |
+
|
394 |
+
**OCRFlux** is licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0).
|
395 |
+
A full copy of the license can be found [on GitHub](https://github.com/allenai/OCRFlux/blob/main/LICENSE).
|
396 |
+
|
397 |
+
<!-- end license -->
|
eval/eval.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# page_to_markdown task
|
4 |
+
python -m ocrflux.pipeline ./eval_page_to_markdown_result --task pdf2markdown --data /data/OCRFlux-bench-single/pdfs/*.pdf --model /data/OCRFlux-7B
|
5 |
+
|
6 |
+
python -m eval.eval_page_to_markdown ./eval_page_to_markdown_result --gt_file /data/OCRFlux-bench-single/data.jsonl
|
7 |
+
|
8 |
+
# element_merge_detect task
|
9 |
+
python -m eval.gen_element_merge_detect_data /data/OCRFlux-bench-cross
|
10 |
+
|
11 |
+
python -m ocrflux.pipeline ./eval_element_merge_detect_result --task merge_pages --data /data/OCRFlux-bench-cross/jsons/*.json --model /data/OCRFlux-7B
|
12 |
+
|
13 |
+
python -m eval.eval_element_merge_detect ./eval_element_merge_detect_result --gt_file /data/OCRFlux-bench-cross/data.jsonl
|
14 |
+
|
15 |
+
# table_to_html task
|
16 |
+
python -m ocrflux.pipeline ./eval_table_to_html_result --task pdf2markdown --data /data/OCRFlux-pubtabnet-single/images/*.png --model /data/OCRFlux-7B
|
17 |
+
|
18 |
+
python -m eval.eval_table_to_html ./eval_table_to_html_result --gt_file /data/OCRFlux-pubtabnet-single/data.jsonl
|
19 |
+
|
20 |
+
# html_table_merge task
|
21 |
+
python -m eval.gen_html_table_merge_data /data/OCRFlux-pubtabnet-cross
|
22 |
+
|
23 |
+
python -m ocrflux.pipeline ./eval_html_table_merge_result --task merge_tables --data /data/OCRFlux-pubtabnet-cross/jsons/*.json --model /data/OCRFlux-7B
|
24 |
+
|
25 |
+
python -m eval.eval_html_table_merge ./eval_html_table_merge_result --gt_file /data/OCRFlux-pubtabnet-cross/data.jsonl
|
26 |
+
|
27 |
+
|
eval/eval_element_merge_detect.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import nltk
|
5 |
+
from tqdm import tqdm
|
6 |
+
from eval.parallel import parallel_process
|
7 |
+
|
8 |
+
|
9 |
+
def evaluate(pred, gt):
|
10 |
+
pred = sorted(pred, key=lambda x: (x[0], x[1]))
|
11 |
+
gt = sorted(gt, key=lambda x: (x[0], x[1]))
|
12 |
+
if pred == gt:
|
13 |
+
return 1
|
14 |
+
else:
|
15 |
+
return 0
|
16 |
+
|
17 |
+
def main():
|
18 |
+
parser = argparse.ArgumentParser(description="Evaluate element_merge_detect task")
|
19 |
+
parser.add_argument(
|
20 |
+
"workspace",
|
21 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
22 |
+
)
|
23 |
+
parser.add_argument(
|
24 |
+
"--gt_file",
|
25 |
+
help="Ground truth file",
|
26 |
+
)
|
27 |
+
parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
|
28 |
+
args = parser.parse_args()
|
29 |
+
|
30 |
+
pred_data = {}
|
31 |
+
root_dir = os.path.join(args.workspace, "results")
|
32 |
+
for jsonl_file in os.listdir(root_dir):
|
33 |
+
if jsonl_file.endswith(".jsonl"):
|
34 |
+
with open(os.path.join(root_dir, jsonl_file), "r") as f:
|
35 |
+
for line in f:
|
36 |
+
data = json.loads(line)
|
37 |
+
pred_data[os.path.basename(data['orig_path'])] = data['merge_pairs']
|
38 |
+
|
39 |
+
filename_list_en = []
|
40 |
+
filename_list_zh = []
|
41 |
+
gt_data = {}
|
42 |
+
with open(args.gt_file, "r") as f:
|
43 |
+
for line in f:
|
44 |
+
data = json.loads(line)
|
45 |
+
pdf_name_1 = data['pdf_name_1'].split(".")[0]
|
46 |
+
pdf_name_2 = data['pdf_name_2'].split(".")[0]
|
47 |
+
|
48 |
+
pdf_name,page_1 = pdf_name_1.split('_')
|
49 |
+
pdf_name,page_2 = pdf_name_2.split('_')
|
50 |
+
|
51 |
+
json_name = pdf_name + '_' + page_1 + '_' + page_2 + '.json'
|
52 |
+
gt_data[json_name] = data['merging_idx_pairs']
|
53 |
+
|
54 |
+
if data['language'] == 'en':
|
55 |
+
filename_list_en.append(json_name)
|
56 |
+
else:
|
57 |
+
filename_list_zh.append(json_name)
|
58 |
+
|
59 |
+
keys = list(gt_data.keys())
|
60 |
+
if args.n_jobs == 1:
|
61 |
+
scores = [evaluate(pred_data.get(filename, []), gt_data.get(filename, [])) for filename in tqdm(keys)]
|
62 |
+
else:
|
63 |
+
inputs = [{'pred': pred_data.get(filename, []), 'gt': gt_data.get(filename, [])} for filename in keys]
|
64 |
+
scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)
|
65 |
+
|
66 |
+
tp_en = 0
|
67 |
+
tn_en = 0
|
68 |
+
fp_en = 0
|
69 |
+
fn_en = 0
|
70 |
+
tp_zh = 0
|
71 |
+
tn_zh = 0
|
72 |
+
fp_zh = 0
|
73 |
+
fn_zh = 0
|
74 |
+
score_en = 0
|
75 |
+
score_zh = 0
|
76 |
+
num_en = 0
|
77 |
+
num_zh = 0
|
78 |
+
for filename, score in zip(keys, scores):
|
79 |
+
print(filename)
|
80 |
+
print(score)
|
81 |
+
print()
|
82 |
+
pred_label = pred_data[filename]
|
83 |
+
if filename in filename_list_en:
|
84 |
+
if pred_label == []:
|
85 |
+
if score == 1:
|
86 |
+
tn_en += 1
|
87 |
+
else:
|
88 |
+
fn_en += 1
|
89 |
+
else:
|
90 |
+
if score == 1:
|
91 |
+
tp_en += 1
|
92 |
+
else:
|
93 |
+
fp_en += 1
|
94 |
+
score_en += score
|
95 |
+
num_en += 1
|
96 |
+
|
97 |
+
elif filename in filename_list_zh:
|
98 |
+
if pred_label == []:
|
99 |
+
if score == 1:
|
100 |
+
tn_zh += 1
|
101 |
+
else:
|
102 |
+
fn_zh += 1
|
103 |
+
else:
|
104 |
+
if score == 1:
|
105 |
+
tp_zh += 1
|
106 |
+
else:
|
107 |
+
fp_zh += 1
|
108 |
+
score_zh += score
|
109 |
+
num_zh += 1
|
110 |
+
|
111 |
+
precision_en = tp_en / (tp_en + fp_en)
|
112 |
+
recall_en = tp_en / (tp_en + fn_en)
|
113 |
+
f1_en = 2*precision_en*recall_en / (precision_en+recall_en)
|
114 |
+
acc_en = score_en / num_en
|
115 |
+
|
116 |
+
precision_zh = tp_zh / (tp_zh + fp_zh)
|
117 |
+
recall_zh = tp_zh / (tp_zh + fn_zh)
|
118 |
+
f1_zh = 2*precision_zh*recall_zh / (precision_zh+recall_zh)
|
119 |
+
acc_zh = score_zh / num_zh
|
120 |
+
|
121 |
+
tp = tp_en + tp_zh
|
122 |
+
fp = fp_en + fp_zh
|
123 |
+
fn = fn_en + fn_zh
|
124 |
+
score = score_en + score_zh
|
125 |
+
num = num_en + num_zh
|
126 |
+
|
127 |
+
precision = tp / (tp + fp)
|
128 |
+
recall = tp / (tp + fn)
|
129 |
+
f1 = 2*precision*recall / (precision+recall)
|
130 |
+
acc = score / num
|
131 |
+
|
132 |
+
print(f"EN: {precision_en} / {recall_en} / {f1_en} / {acc_en}")
|
133 |
+
print(f"ZH: {precision_zh} / {recall_zh} / {f1_zh} / {acc_zh}")
|
134 |
+
print(f"ALL: {precision} / {recall} / {f1} / {acc}")
|
135 |
+
|
136 |
+
if __name__ == "__main__":
|
137 |
+
main()
|
eval/eval_html_table_merge.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import distance
|
5 |
+
from apted import APTED, Config
|
6 |
+
from apted.helpers import Tree
|
7 |
+
from lxml import etree, html
|
8 |
+
from collections import deque
|
9 |
+
from tqdm import tqdm
|
10 |
+
from eval.parallel import parallel_process
|
11 |
+
|
12 |
+
|
13 |
+
class TableTree(Tree):
|
14 |
+
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
|
15 |
+
self.tag = tag
|
16 |
+
self.colspan = colspan
|
17 |
+
self.rowspan = rowspan
|
18 |
+
self.content = content
|
19 |
+
self.children = list(children)
|
20 |
+
|
21 |
+
def bracket(self):
|
22 |
+
"""Show tree using brackets notation"""
|
23 |
+
if self.tag == 'td':
|
24 |
+
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
|
25 |
+
(self.tag, self.colspan, self.rowspan, self.content)
|
26 |
+
else:
|
27 |
+
result = '"tag": %s' % self.tag
|
28 |
+
for child in self.children:
|
29 |
+
result += child.bracket()
|
30 |
+
return "{{{}}}".format(result)
|
31 |
+
|
32 |
+
|
33 |
+
class CustomConfig(Config):
|
34 |
+
@staticmethod
|
35 |
+
def maximum(*sequences):
|
36 |
+
"""Get maximum possible value
|
37 |
+
"""
|
38 |
+
return max(map(len, sequences))
|
39 |
+
|
40 |
+
def normalized_distance(self, *sequences):
|
41 |
+
"""Get distance from 0 to 1
|
42 |
+
"""
|
43 |
+
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
|
44 |
+
|
45 |
+
def rename(self, node1, node2):
|
46 |
+
"""Compares attributes of trees"""
|
47 |
+
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
|
48 |
+
return 1.
|
49 |
+
if node1.tag == 'td':
|
50 |
+
if node1.content or node2.content:
|
51 |
+
return self.normalized_distance(node1.content, node2.content)
|
52 |
+
return 0.
|
53 |
+
|
54 |
+
|
55 |
+
class TEDS(object):
|
56 |
+
''' Tree Edit Distance basead Similarity
|
57 |
+
'''
|
58 |
+
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
|
59 |
+
assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greather than 1'
|
60 |
+
self.structure_only = structure_only
|
61 |
+
self.n_jobs = n_jobs
|
62 |
+
self.ignore_nodes = ignore_nodes
|
63 |
+
self.__tokens__ = []
|
64 |
+
|
65 |
+
def tokenize(self, node):
|
66 |
+
''' Tokenizes table cells
|
67 |
+
'''
|
68 |
+
self.__tokens__.append('<%s>' % node.tag)
|
69 |
+
if node.text is not None:
|
70 |
+
self.__tokens__ += list(node.text)
|
71 |
+
for n in node.getchildren():
|
72 |
+
self.tokenize(n)
|
73 |
+
if node.tag != 'unk':
|
74 |
+
self.__tokens__.append('</%s>' % node.tag)
|
75 |
+
if node.tag != 'td' and node.tail is not None:
|
76 |
+
self.__tokens__ += list(node.tail)
|
77 |
+
|
78 |
+
def load_html_tree(self, node, parent=None):
|
79 |
+
''' Converts HTML tree to the format required by apted
|
80 |
+
'''
|
81 |
+
global __tokens__
|
82 |
+
if node.tag == 'td':
|
83 |
+
if self.structure_only:
|
84 |
+
cell = []
|
85 |
+
else:
|
86 |
+
self.__tokens__ = []
|
87 |
+
self.tokenize(node)
|
88 |
+
cell = self.__tokens__[1:-1].copy()
|
89 |
+
new_node = TableTree(node.tag,
|
90 |
+
int(node.attrib.get('colspan', '1')),
|
91 |
+
int(node.attrib.get('rowspan', '1')),
|
92 |
+
cell, *deque())
|
93 |
+
else:
|
94 |
+
new_node = TableTree(node.tag, None, None, None, *deque())
|
95 |
+
if parent is not None:
|
96 |
+
parent.children.append(new_node)
|
97 |
+
if node.tag != 'td':
|
98 |
+
for n in node.getchildren():
|
99 |
+
self.load_html_tree(n, new_node)
|
100 |
+
if parent is None:
|
101 |
+
return new_node
|
102 |
+
|
103 |
+
def evaluate(self, pred, true):
|
104 |
+
''' Computes TEDS score between the prediction and the ground truth of a
|
105 |
+
given sample
|
106 |
+
'''
|
107 |
+
if (not pred) or (not true):
|
108 |
+
return 0.0
|
109 |
+
pred = "<html>" + pred + "</html>"
|
110 |
+
true = "<html>" + true + "</html>"
|
111 |
+
parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
|
112 |
+
pred = html.fromstring(pred, parser=parser)
|
113 |
+
true = html.fromstring(true, parser=parser)
|
114 |
+
if pred.xpath('body/table') and true.xpath('body/table'):
|
115 |
+
pred = pred.xpath('body/table')[0]
|
116 |
+
true = true.xpath('body/table')[0]
|
117 |
+
if self.ignore_nodes:
|
118 |
+
etree.strip_tags(pred, *self.ignore_nodes)
|
119 |
+
etree.strip_tags(true, *self.ignore_nodes)
|
120 |
+
n_nodes_pred = len(pred.xpath(".//*"))
|
121 |
+
n_nodes_true = len(true.xpath(".//*"))
|
122 |
+
n_nodes = max(n_nodes_pred, n_nodes_true)
|
123 |
+
tree_pred = self.load_html_tree(pred)
|
124 |
+
tree_true = self.load_html_tree(true)
|
125 |
+
distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
|
126 |
+
return 1.0 - (float(distance) / n_nodes)
|
127 |
+
else:
|
128 |
+
return 0.0
|
129 |
+
|
130 |
+
def batch_evaluate(self, pred_json, true_json):
|
131 |
+
''' Computes TEDS score between the prediction and the ground truth of
|
132 |
+
a batch of samples
|
133 |
+
@params pred_json: {'FILENAME': 'HTML CODE', ...}
|
134 |
+
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
|
135 |
+
@output: {'FILENAME': 'TEDS SCORE', ...}
|
136 |
+
'''
|
137 |
+
samples = true_json.keys()
|
138 |
+
if self.n_jobs == 1:
|
139 |
+
scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
|
140 |
+
else:
|
141 |
+
inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
|
142 |
+
scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
|
143 |
+
total_score_simple = 0
|
144 |
+
num_simple = 0
|
145 |
+
total_score_complex = 0
|
146 |
+
num_complex = 0
|
147 |
+
total_score = 0
|
148 |
+
num_total = 0
|
149 |
+
for filename,score in zip(samples, scores):
|
150 |
+
print(filename)
|
151 |
+
print(score)
|
152 |
+
print('')
|
153 |
+
if true_json[filename]['type'] == 'simple':
|
154 |
+
total_score_simple += score
|
155 |
+
num_simple += 1
|
156 |
+
elif true_json[filename]['type'] == 'complex':
|
157 |
+
total_score_complex += score
|
158 |
+
num_complex += 1
|
159 |
+
else:
|
160 |
+
raise ValueError('Unknown type: %s' % true_json[filename]['type'])
|
161 |
+
total_score += score
|
162 |
+
num_total += 1
|
163 |
+
if num_simple > 0:
|
164 |
+
avg_score_simple = total_score_simple / num_simple
|
165 |
+
else:
|
166 |
+
avg_score_simple = 0
|
167 |
+
if num_complex > 0:
|
168 |
+
avg_score_complex = total_score_complex / num_complex
|
169 |
+
else:
|
170 |
+
avg_score_complex = 0
|
171 |
+
avg_score = total_score / num_total
|
172 |
+
print({'simple': (num_simple,avg_score_simple), 'complex': (num_complex,avg_score_complex), 'total': (num_total,avg_score)})
|
173 |
+
|
174 |
+
def main():
|
175 |
+
parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
|
176 |
+
parser.add_argument(
|
177 |
+
"workspace",
|
178 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--gt_file",
|
182 |
+
help="Ground truth file",
|
183 |
+
)
|
184 |
+
parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
|
185 |
+
args = parser.parse_args()
|
186 |
+
|
187 |
+
pred_data = {}
|
188 |
+
root_dir = os.path.join(args.workspace, "results")
|
189 |
+
for jsonl_file in os.listdir(root_dir):
|
190 |
+
if jsonl_file.endswith(".jsonl"):
|
191 |
+
with open(os.path.join(root_dir, jsonl_file), "r") as f:
|
192 |
+
for line in f:
|
193 |
+
data = json.loads(line)
|
194 |
+
key = os.path.basename(data['orig_path']).split('.')[0]
|
195 |
+
pred_data[key] = data['merged_tables']
|
196 |
+
|
197 |
+
gt_data = {}
|
198 |
+
with open(args.gt_file, "r") as f:
|
199 |
+
for line in f:
|
200 |
+
data = json.loads(line)
|
201 |
+
key = data['image_name'].split('.')[0]
|
202 |
+
gt_data[key] = {'html':data['gt_table'], 'type':data['type']}
|
203 |
+
|
204 |
+
teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
|
205 |
+
teds.batch_evaluate(pred_data, gt_data)
|
206 |
+
|
207 |
+
if __name__ == "__main__":
|
208 |
+
main()
|
eval/eval_page_to_markdown.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import nltk
|
5 |
+
from tqdm import tqdm
|
6 |
+
from eval.parallel import parallel_process
|
7 |
+
|
8 |
+
|
9 |
+
def evaluate(pred, gt):
|
10 |
+
edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
|
11 |
+
return 1.0 - edit_dist
|
12 |
+
|
13 |
+
|
14 |
+
def main():
|
15 |
+
parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
|
16 |
+
parser.add_argument(
|
17 |
+
"workspace",
|
18 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
19 |
+
)
|
20 |
+
parser.add_argument(
|
21 |
+
"--gt_file",
|
22 |
+
help="Ground truth file",
|
23 |
+
)
|
24 |
+
parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
|
25 |
+
args = parser.parse_args()
|
26 |
+
|
27 |
+
pred_data = {}
|
28 |
+
root_dir = os.path.join(args.workspace, "results")
|
29 |
+
for jsonl_file in os.listdir(root_dir):
|
30 |
+
if jsonl_file.endswith(".jsonl"):
|
31 |
+
with open(os.path.join(root_dir, jsonl_file), "r") as f:
|
32 |
+
for line in f:
|
33 |
+
data = json.loads(line)
|
34 |
+
pred_data[os.path.basename(data['orig_path'])] = data['document_text']
|
35 |
+
|
36 |
+
filename_list_en = []
|
37 |
+
filename_list_zh = []
|
38 |
+
gt_data = {}
|
39 |
+
with open(args.gt_file, "r") as f:
|
40 |
+
for line in f:
|
41 |
+
data = json.loads(line)
|
42 |
+
markdown = data['markdown']
|
43 |
+
pdf_name = data['pdf_name']
|
44 |
+
gt_data[pdf_name] = markdown
|
45 |
+
if data['language'] == 'en':
|
46 |
+
filename_list_en.append(pdf_name)
|
47 |
+
else:
|
48 |
+
filename_list_zh.append(pdf_name)
|
49 |
+
|
50 |
+
keys = list(gt_data.keys())
|
51 |
+
if args.n_jobs == 1:
|
52 |
+
scores = [evaluate(pred_data.get(filename, ''), gt_data.get(filename, '')) for filename in tqdm(keys)]
|
53 |
+
else:
|
54 |
+
inputs = [{'pred': pred_data.get(filename, ''), 'gt': gt_data.get(filename, '')} for filename in keys]
|
55 |
+
scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)
|
56 |
+
|
57 |
+
total_score_en = 0
|
58 |
+
total_num_en = 0
|
59 |
+
total_score_zh = 0
|
60 |
+
total_num_zh = 0
|
61 |
+
for filename, score in zip(keys, scores):
|
62 |
+
print(filename)
|
63 |
+
print(score)
|
64 |
+
print()
|
65 |
+
if filename in filename_list_en:
|
66 |
+
total_score_en += score
|
67 |
+
total_num_en += 1
|
68 |
+
elif filename in filename_list_zh:
|
69 |
+
total_score_zh += score
|
70 |
+
total_num_zh += 1
|
71 |
+
print(f"English: {total_score_en / total_num_en}")
|
72 |
+
print(f"Chinese: {total_score_zh / total_num_zh}")
|
73 |
+
print(f"Total: {sum(scores) / len(scores)}")
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
main()
|
eval/eval_page_to_markdown_nanonets.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
import argparse
|
5 |
+
import nltk
|
6 |
+
import markdown2
|
7 |
+
from bs4 import BeautifulSoup
|
8 |
+
from tqdm import tqdm
|
9 |
+
from eval.parallel import parallel_process
|
10 |
+
|
11 |
+
def turn_header_to_h1(line):
|
12 |
+
# 检查是否是以一个或多个 '#' 开头的标题行
|
13 |
+
if line.lstrip().startswith('#'):
|
14 |
+
# 去掉开头的 '#' 和其后的空格
|
15 |
+
new_line = "# " + line.lstrip().lstrip('#').lstrip()
|
16 |
+
return new_line
|
17 |
+
else:
|
18 |
+
return line
|
19 |
+
|
20 |
+
def replace_single_dollar(markdown_text):
|
21 |
+
pattern = r'\$(.*?)\$'
|
22 |
+
def replace_with_brackets(match):
|
23 |
+
formula_content = match.group(1) # 获取匹配到的公式内容
|
24 |
+
return f'\\({formula_content}\\)'
|
25 |
+
|
26 |
+
replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
|
27 |
+
|
28 |
+
return replaced_text
|
29 |
+
|
30 |
+
|
31 |
+
def replace_double_dollar(markdown_text):
|
32 |
+
pattern = r'\$\$(.*?)\$\$'
|
33 |
+
def replace_with_brackets(match):
|
34 |
+
formula_content = match.group(1)
|
35 |
+
return f'\\[{formula_content}\\]'
|
36 |
+
replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
|
37 |
+
|
38 |
+
return replaced_text
|
39 |
+
|
40 |
+
def simplify_html_table(html_table):
|
41 |
+
# 使用 BeautifulSoup 解析 HTML
|
42 |
+
soup = BeautifulSoup(html_table, 'html.parser')
|
43 |
+
|
44 |
+
# 找到 <table> 标签
|
45 |
+
table = soup.find('table')
|
46 |
+
if not table:
|
47 |
+
raise ValueError("输入的 HTML 不包含有效的 <table> 标签")
|
48 |
+
|
49 |
+
# 创建一个新的 <table> 标签
|
50 |
+
new_table = BeautifulSoup('<table></table>', 'html.parser').table
|
51 |
+
|
52 |
+
# 提取所有行(包括 <thead> 和 <tbody> 中的行)
|
53 |
+
rows = table.find_all(['tr'], recursive=True)
|
54 |
+
|
55 |
+
for row in rows:
|
56 |
+
# 创建新的 <tr> 标签
|
57 |
+
new_row = soup.new_tag('tr')
|
58 |
+
|
59 |
+
# 处理每一行中的单元格
|
60 |
+
cells = row.find_all(['th', 'td'])
|
61 |
+
for cell in cells:
|
62 |
+
# 将 <th> 替换为 <td>
|
63 |
+
new_cell = soup.new_tag('td')
|
64 |
+
if cell.has_attr('rowspan'):
|
65 |
+
new_cell['rowspan'] = cell['rowspan']
|
66 |
+
if cell.has_attr('colspan'):
|
67 |
+
new_cell['colspan'] = cell['colspan']
|
68 |
+
new_cell.string = cell.get_text(strip=True) # 保留单元格内容
|
69 |
+
new_row.append(new_cell)
|
70 |
+
|
71 |
+
# 将新行添加到新表格中
|
72 |
+
new_table.append(new_row)
|
73 |
+
|
74 |
+
# 返回简化后的表格 HTML
|
75 |
+
return str(new_table)
|
76 |
+
|
77 |
+
def evaluate(pred, gt):
|
78 |
+
edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
|
79 |
+
return 1.0- edit_dist
|
80 |
+
|
81 |
+
|
82 |
+
def main():
|
83 |
+
parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
|
84 |
+
parser.add_argument(
|
85 |
+
"workspace",
|
86 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
87 |
+
)
|
88 |
+
parser.add_argument(
|
89 |
+
"--gt_file",
|
90 |
+
help="Ground truth file",
|
91 |
+
)
|
92 |
+
parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
|
93 |
+
args = parser.parse_args()
|
94 |
+
|
95 |
+
pred_data = {}
|
96 |
+
for file in os.listdir(args.workspace):
|
97 |
+
file_path = os.path.join(args.workspace, file)
|
98 |
+
pdf_name = file.split('.')[0] + ".pdf"
|
99 |
+
with open(file_path, "r") as f:
|
100 |
+
document_text = f.read()
|
101 |
+
document_text = replace_single_dollar(replace_double_dollar(document_text))
|
102 |
+
markdown_text_list = document_text.split("\n\n")
|
103 |
+
new_markdown_text_list = []
|
104 |
+
for text in markdown_text_list:
|
105 |
+
text = text.strip()
|
106 |
+
if (text.startswith("<watermark>") and text.endswith("</watermark>")) or (text.startswith("<img>") and text.endswith("</img>")) or (text.startswith("<page_number>") and text.endswith("</page_number>")) or (text.startswith("<signature>") and text.endswith("</signature>")):
|
107 |
+
continue
|
108 |
+
else:
|
109 |
+
html_text = str(markdown2.markdown(text,extras=["tables"]))
|
110 |
+
html_text = html_text.strip()
|
111 |
+
if html_text.startswith("<table>") and html_text.endswith("</table>"):
|
112 |
+
html_table = simplify_html_table(html_text)
|
113 |
+
new_markdown_text_list.append(html_table)
|
114 |
+
else:
|
115 |
+
text = turn_header_to_h1(text)
|
116 |
+
new_markdown_text_list.append(text)
|
117 |
+
|
118 |
+
pred_data[os.path.basename(pdf_name)] = "\n\n".join(new_markdown_text_list)
|
119 |
+
|
120 |
+
filename_list_en = []
|
121 |
+
filename_list_zh = []
|
122 |
+
gt_data = {}
|
123 |
+
with open(args.gt_file, "r") as f:
|
124 |
+
for line in f:
|
125 |
+
data = json.loads(line)
|
126 |
+
markdown = data['markdown']
|
127 |
+
pdf_name = data['pdf_name']
|
128 |
+
gt_data[pdf_name] = markdown
|
129 |
+
if data['language'] == 'en':
|
130 |
+
filename_list_en.append(pdf_name)
|
131 |
+
else:
|
132 |
+
filename_list_zh.append(pdf_name)
|
133 |
+
|
134 |
+
keys = list(gt_data.keys())
|
135 |
+
if args.n_jobs == 1:
|
136 |
+
scores = [evaluate(pred_data.get(filename, ''), gt_data.get(filename, '')) for filename in tqdm(keys)]
|
137 |
+
else:
|
138 |
+
inputs = [{'pred': pred_data.get(filename, ''), 'gt': gt_data.get(filename, '')} for filename in keys]
|
139 |
+
scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)
|
140 |
+
|
141 |
+
total_score_en = 0
|
142 |
+
total_num_en = 0
|
143 |
+
total_score_zh = 0
|
144 |
+
total_num_zh = 0
|
145 |
+
for filename, score in zip(keys, scores):
|
146 |
+
if filename in filename_list_en:
|
147 |
+
print(filename)
|
148 |
+
print(score)
|
149 |
+
print()
|
150 |
+
total_score_en += score
|
151 |
+
total_num_en += 1
|
152 |
+
elif filename in filename_list_zh:
|
153 |
+
total_score_zh += score
|
154 |
+
total_num_zh += 1
|
155 |
+
print(f"English: {total_score_en / total_num_en}")
|
156 |
+
print(f"Chinese: {total_score_zh / total_num_zh}")
|
157 |
+
print(f"Total: {sum(scores) / len(scores)}")
|
158 |
+
|
159 |
+
if __name__ == "__main__":
|
160 |
+
main()
|
eval/eval_page_to_markdown_olmocr.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
import argparse
|
5 |
+
import nltk
|
6 |
+
import markdown2
|
7 |
+
from bs4 import BeautifulSoup
|
8 |
+
from tqdm import tqdm
|
9 |
+
from eval.parallel import parallel_process
|
10 |
+
|
11 |
+
def turn_header_to_h1(line):
|
12 |
+
# 检查是否是以一个或多个 '#' 开头的标题行
|
13 |
+
if line.lstrip().startswith('#'):
|
14 |
+
# 去掉开头的 '#' 和其后的空格
|
15 |
+
new_line = "# " + line.lstrip().lstrip('#').lstrip()
|
16 |
+
return new_line
|
17 |
+
else:
|
18 |
+
return line
|
19 |
+
|
20 |
+
def replace_single_dollar(markdown_text):
|
21 |
+
pattern = r'\$(.*?)\$'
|
22 |
+
def replace_with_brackets(match):
|
23 |
+
formula_content = match.group(1) # 获取匹配到的公式内容
|
24 |
+
return f'\\({formula_content}\\)'
|
25 |
+
|
26 |
+
replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
|
27 |
+
|
28 |
+
return replaced_text
|
29 |
+
|
30 |
+
|
31 |
+
def replace_double_dollar(markdown_text):
|
32 |
+
pattern = r'\$\$(.*?)\$\$'
|
33 |
+
def replace_with_brackets(match):
|
34 |
+
formula_content = match.group(1)
|
35 |
+
return f'\\[{formula_content}\\]'
|
36 |
+
replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
|
37 |
+
|
38 |
+
return replaced_text
|
39 |
+
|
40 |
+
def simplify_html_table(html_table):
|
41 |
+
# 使用 BeautifulSoup 解析 HTML
|
42 |
+
soup = BeautifulSoup(html_table, 'html.parser')
|
43 |
+
|
44 |
+
# 找到 <table> 标签
|
45 |
+
table = soup.find('table')
|
46 |
+
if not table:
|
47 |
+
raise ValueError("输入的 HTML 不包含有效的 <table> 标签")
|
48 |
+
|
49 |
+
# 创建一个新的 <table> 标签
|
50 |
+
new_table = BeautifulSoup('<table></table>', 'html.parser').table
|
51 |
+
|
52 |
+
# 提取所有行(包括 <thead> 和 <tbody> 中的行)
|
53 |
+
rows = table.find_all(['tr'], recursive=True)
|
54 |
+
|
55 |
+
for row in rows:
|
56 |
+
# 创建新的 <tr> 标签
|
57 |
+
new_row = soup.new_tag('tr')
|
58 |
+
|
59 |
+
# 处理每一行中的单元格
|
60 |
+
cells = row.find_all(['th', 'td'])
|
61 |
+
for cell in cells:
|
62 |
+
# 将 <th> 替换为 <td>
|
63 |
+
new_cell = soup.new_tag('td')
|
64 |
+
new_cell.string = cell.get_text(strip=True) # 保留单元格内容
|
65 |
+
new_row.append(new_cell)
|
66 |
+
|
67 |
+
# 将新行添加到新表格中
|
68 |
+
new_table.append(new_row)
|
69 |
+
|
70 |
+
# 返回简化后的表格 HTML
|
71 |
+
return str(new_table)
|
72 |
+
|
73 |
+
def evaluate(pred, gt):
|
74 |
+
edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
|
75 |
+
return 1.0- edit_dist
|
76 |
+
|
77 |
+
|
78 |
+
def main():
|
79 |
+
parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
|
80 |
+
parser.add_argument(
|
81 |
+
"workspace",
|
82 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
83 |
+
)
|
84 |
+
parser.add_argument(
|
85 |
+
"--gt_file",
|
86 |
+
help="Ground truth file",
|
87 |
+
)
|
88 |
+
parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
|
89 |
+
args = parser.parse_args()
|
90 |
+
|
91 |
+
pred_data = {}
|
92 |
+
root_dir = os.path.join(args.workspace, "results")
|
93 |
+
for jsonl_file in os.listdir(root_dir):
|
94 |
+
if jsonl_file.endswith(".jsonl"):
|
95 |
+
with open(os.path.join(root_dir, jsonl_file), "r") as f:
|
96 |
+
for line in f:
|
97 |
+
data = json.loads(line)
|
98 |
+
pdf_path = data['metadata']['Source-File']
|
99 |
+
document_text = data['text']
|
100 |
+
document_text = replace_single_dollar(replace_double_dollar(document_text))
|
101 |
+
|
102 |
+
markdown_text_list = document_text.split("\n\n")
|
103 |
+
|
104 |
+
new_markdown_text_list = []
|
105 |
+
for text in markdown_text_list:
|
106 |
+
html_text = str(markdown2.markdown(text,extras=["tables"]))
|
107 |
+
html_text = html_text.strip()
|
108 |
+
if html_text.startswith("<table>") and html_text.endswith("</table>"):
|
109 |
+
html_table = simplify_html_table(html_text)
|
110 |
+
new_markdown_text_list.append(html_table)
|
111 |
+
else:
|
112 |
+
text = turn_header_to_h1(text)
|
113 |
+
new_markdown_text_list.append(text)
|
114 |
+
|
115 |
+
pred_data[os.path.basename(pdf_path)] = "\n\n".join(new_markdown_text_list)
|
116 |
+
|
117 |
+
filename_list_en = []
|
118 |
+
filename_list_zh = []
|
119 |
+
gt_data = {}
|
120 |
+
with open(args.gt_file, "r") as f:
|
121 |
+
for line in f:
|
122 |
+
data = json.loads(line)
|
123 |
+
markdown = data['markdown']
|
124 |
+
pdf_name = data['pdf_name']
|
125 |
+
gt_data[pdf_name] = markdown
|
126 |
+
if data['language'] == 'en':
|
127 |
+
filename_list_en.append(pdf_name)
|
128 |
+
else:
|
129 |
+
filename_list_zh.append(pdf_name)
|
130 |
+
|
131 |
+
keys = list(gt_data.keys())
|
132 |
+
if args.n_jobs == 1:
|
133 |
+
scores = [evaluate(pred_data.get(filename, ''), gt_data.get(filename, '')) for filename in tqdm(keys)]
|
134 |
+
else:
|
135 |
+
inputs = [{'pred': pred_data.get(filename, ''), 'gt': gt_data.get(filename, '')} for filename in keys]
|
136 |
+
scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)
|
137 |
+
|
138 |
+
total_score_en = 0
|
139 |
+
total_num_en = 0
|
140 |
+
total_score_zh = 0
|
141 |
+
total_num_zh = 0
|
142 |
+
for filename, score in zip(keys, scores):
|
143 |
+
if filename in filename_list_en:
|
144 |
+
print(filename)
|
145 |
+
print(score)
|
146 |
+
print()
|
147 |
+
total_score_en += score
|
148 |
+
total_num_en += 1
|
149 |
+
elif filename in filename_list_zh:
|
150 |
+
total_score_zh += score
|
151 |
+
total_num_zh += 1
|
152 |
+
print(f"English: {total_score_en / total_num_en}")
|
153 |
+
print(f"Chinese: {total_score_zh / total_num_zh}")
|
154 |
+
print(f"Total: {sum(scores) / len(scores)}")
|
155 |
+
|
156 |
+
if __name__ == "__main__":
|
157 |
+
main()
|
eval/eval_table_to_html.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import distance
|
5 |
+
from apted import APTED, Config
|
6 |
+
from apted.helpers import Tree
|
7 |
+
from lxml import etree, html
|
8 |
+
from collections import deque
|
9 |
+
from tqdm import tqdm
|
10 |
+
from eval.parallel import parallel_process
|
11 |
+
|
12 |
+
|
13 |
+
class TableTree(Tree):
|
14 |
+
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
|
15 |
+
self.tag = tag
|
16 |
+
self.colspan = colspan
|
17 |
+
self.rowspan = rowspan
|
18 |
+
self.content = content
|
19 |
+
self.children = list(children)
|
20 |
+
|
21 |
+
def bracket(self):
|
22 |
+
"""Show tree using brackets notation"""
|
23 |
+
if self.tag == 'td':
|
24 |
+
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
|
25 |
+
(self.tag, self.colspan, self.rowspan, self.content)
|
26 |
+
else:
|
27 |
+
result = '"tag": %s' % self.tag
|
28 |
+
for child in self.children:
|
29 |
+
result += child.bracket()
|
30 |
+
return "{{{}}}".format(result)
|
31 |
+
|
32 |
+
|
33 |
+
class CustomConfig(Config):
|
34 |
+
@staticmethod
|
35 |
+
def maximum(*sequences):
|
36 |
+
"""Get maximum possible value
|
37 |
+
"""
|
38 |
+
return max(map(len, sequences))
|
39 |
+
|
40 |
+
def normalized_distance(self, *sequences):
|
41 |
+
"""Get distance from 0 to 1
|
42 |
+
"""
|
43 |
+
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
|
44 |
+
|
45 |
+
def rename(self, node1, node2):
|
46 |
+
"""Compares attributes of trees"""
|
47 |
+
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
|
48 |
+
return 1.
|
49 |
+
if node1.tag == 'td':
|
50 |
+
if node1.content or node2.content:
|
51 |
+
return self.normalized_distance(node1.content, node2.content)
|
52 |
+
return 0.
|
53 |
+
|
54 |
+
|
55 |
+
class TEDS(object):
|
56 |
+
''' Tree Edit Distance basead Similarity
|
57 |
+
'''
|
58 |
+
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
|
59 |
+
assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greather than 1'
|
60 |
+
self.structure_only = structure_only
|
61 |
+
self.n_jobs = n_jobs
|
62 |
+
self.ignore_nodes = ignore_nodes
|
63 |
+
self.__tokens__ = []
|
64 |
+
|
65 |
+
def tokenize(self, node):
|
66 |
+
''' Tokenizes table cells
|
67 |
+
'''
|
68 |
+
self.__tokens__.append('<%s>' % node.tag)
|
69 |
+
if node.text is not None:
|
70 |
+
self.__tokens__ += list(node.text)
|
71 |
+
for n in node.getchildren():
|
72 |
+
self.tokenize(n)
|
73 |
+
if node.tag != 'unk':
|
74 |
+
self.__tokens__.append('</%s>' % node.tag)
|
75 |
+
if node.tag != 'td' and node.tail is not None:
|
76 |
+
self.__tokens__ += list(node.tail)
|
77 |
+
|
78 |
+
def load_html_tree(self, node, parent=None):
|
79 |
+
''' Converts HTML tree to the format required by apted
|
80 |
+
'''
|
81 |
+
global __tokens__
|
82 |
+
if node.tag == 'td':
|
83 |
+
if self.structure_only:
|
84 |
+
cell = []
|
85 |
+
else:
|
86 |
+
self.__tokens__ = []
|
87 |
+
self.tokenize(node)
|
88 |
+
cell = self.__tokens__[1:-1].copy()
|
89 |
+
new_node = TableTree(node.tag,
|
90 |
+
int(node.attrib.get('colspan', '1')),
|
91 |
+
int(node.attrib.get('rowspan', '1')),
|
92 |
+
cell, *deque())
|
93 |
+
else:
|
94 |
+
new_node = TableTree(node.tag, None, None, None, *deque())
|
95 |
+
if parent is not None:
|
96 |
+
parent.children.append(new_node)
|
97 |
+
if node.tag != 'td':
|
98 |
+
for n in node.getchildren():
|
99 |
+
self.load_html_tree(n, new_node)
|
100 |
+
if parent is None:
|
101 |
+
return new_node
|
102 |
+
|
103 |
+
def evaluate(self, pred, true):
|
104 |
+
''' Computes TEDS score between the prediction and the ground truth of a
|
105 |
+
given sample
|
106 |
+
'''
|
107 |
+
if (not pred) or (not true):
|
108 |
+
return 0.0
|
109 |
+
pred = "<html>" + pred + "</html>"
|
110 |
+
true = "<html>" + true + "</html>"
|
111 |
+
parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
|
112 |
+
pred = html.fromstring(pred, parser=parser)
|
113 |
+
true = html.fromstring(true, parser=parser)
|
114 |
+
if pred.xpath('body/table') and true.xpath('body/table'):
|
115 |
+
pred = pred.xpath('body/table')[0]
|
116 |
+
true = true.xpath('body/table')[0]
|
117 |
+
if self.ignore_nodes:
|
118 |
+
etree.strip_tags(pred, *self.ignore_nodes)
|
119 |
+
etree.strip_tags(true, *self.ignore_nodes)
|
120 |
+
n_nodes_pred = len(pred.xpath(".//*"))
|
121 |
+
n_nodes_true = len(true.xpath(".//*"))
|
122 |
+
n_nodes = max(n_nodes_pred, n_nodes_true)
|
123 |
+
tree_pred = self.load_html_tree(pred)
|
124 |
+
tree_true = self.load_html_tree(true)
|
125 |
+
distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
|
126 |
+
return 1.0 - (float(distance) / n_nodes)
|
127 |
+
else:
|
128 |
+
return 0.0
|
129 |
+
|
130 |
+
def batch_evaluate(self, pred_json, true_json):
|
131 |
+
''' Computes TEDS score between the prediction and the ground truth of
|
132 |
+
a batch of samples
|
133 |
+
@params pred_json: {'FILENAME': 'HTML CODE', ...}
|
134 |
+
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
|
135 |
+
@output: {'FILENAME': 'TEDS SCORE', ...}
|
136 |
+
'''
|
137 |
+
samples = true_json.keys()
|
138 |
+
if self.n_jobs == 1:
|
139 |
+
scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
|
140 |
+
else:
|
141 |
+
inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
|
142 |
+
scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
|
143 |
+
total_score_simple = 0
|
144 |
+
num_simple = 0
|
145 |
+
total_score_complex = 0
|
146 |
+
num_complex = 0
|
147 |
+
total_score = 0
|
148 |
+
num_total = 0
|
149 |
+
for filename,score in zip(samples, scores):
|
150 |
+
print(filename)
|
151 |
+
print(score)
|
152 |
+
print('')
|
153 |
+
if true_json[filename]['type'] == 'simple':
|
154 |
+
total_score_simple += score
|
155 |
+
num_simple += 1
|
156 |
+
elif true_json[filename]['type'] == 'complex':
|
157 |
+
total_score_complex += score
|
158 |
+
num_complex += 1
|
159 |
+
else:
|
160 |
+
raise ValueError('Unknown type: %s' % true_json[filename]['type'])
|
161 |
+
total_score += score
|
162 |
+
num_total += 1
|
163 |
+
if num_simple > 0:
|
164 |
+
avg_score_simple = total_score_simple / num_simple
|
165 |
+
else:
|
166 |
+
avg_score_simple = 0
|
167 |
+
if num_complex > 0:
|
168 |
+
avg_score_complex = total_score_complex / num_complex
|
169 |
+
else:
|
170 |
+
avg_score_complex = 0
|
171 |
+
avg_score = total_score / num_total
|
172 |
+
print({'simple': (num_simple,avg_score_simple), 'complex': (num_complex,avg_score_complex), 'total': (num_total,avg_score)})
|
173 |
+
|
174 |
+
def main():
|
175 |
+
parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
|
176 |
+
parser.add_argument(
|
177 |
+
"workspace",
|
178 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--gt_file",
|
182 |
+
help="Ground truth file",
|
183 |
+
)
|
184 |
+
parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
|
185 |
+
args = parser.parse_args()
|
186 |
+
|
187 |
+
pred_data = {}
|
188 |
+
root_dir = os.path.join(args.workspace, "results")
|
189 |
+
for jsonl_file in os.listdir(root_dir):
|
190 |
+
if jsonl_file.endswith(".jsonl"):
|
191 |
+
with open(os.path.join(root_dir, jsonl_file), "r") as f:
|
192 |
+
for line in f:
|
193 |
+
data = json.loads(line)
|
194 |
+
pred_data[os.path.basename(data['orig_path'])] = data['document_text']
|
195 |
+
|
196 |
+
gt_data = {}
|
197 |
+
with open(args.gt_file, "r") as f:
|
198 |
+
for line in f:
|
199 |
+
data = json.loads(line)
|
200 |
+
gt_data[data['image_name']] = {'html':data['gt_table'], 'type':data['type']}
|
201 |
+
|
202 |
+
teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
|
203 |
+
teds.batch_evaluate(pred_data, gt_data)
|
204 |
+
|
205 |
+
if __name__ == "__main__":
|
206 |
+
main()
|
eval/eval_table_to_html_nanonets.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import distance
|
5 |
+
import markdown2
|
6 |
+
import re
|
7 |
+
from apted import APTED, Config
|
8 |
+
from apted.helpers import Tree
|
9 |
+
from lxml import etree, html
|
10 |
+
from collections import deque
|
11 |
+
from tqdm import tqdm
|
12 |
+
from eval.parallel import parallel_process
|
13 |
+
from bs4 import BeautifulSoup
|
14 |
+
|
15 |
+
def turn_header_to_h1(line):
|
16 |
+
# 检查是否是以一个或多个 '#' 开头的标题行
|
17 |
+
if line.lstrip().startswith('#'):
|
18 |
+
# 去掉开头的 '#' 和其后的空格
|
19 |
+
new_line = "# " + line.lstrip().lstrip('#').lstrip()
|
20 |
+
return new_line
|
21 |
+
else:
|
22 |
+
return line
|
23 |
+
|
24 |
+
def replace_single_dollar(markdown_text):
|
25 |
+
pattern = r'\$(.*?)\$'
|
26 |
+
def replace_with_brackets(match):
|
27 |
+
formula_content = match.group(1) # 获取匹配到的公式内容
|
28 |
+
return f'\\({formula_content}\\)'
|
29 |
+
|
30 |
+
replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
|
31 |
+
|
32 |
+
return replaced_text
|
33 |
+
|
34 |
+
|
35 |
+
def replace_double_dollar(markdown_text):
|
36 |
+
pattern = r'\$\$(.*?)\$\$'
|
37 |
+
def replace_with_brackets(match):
|
38 |
+
formula_content = match.group(1)
|
39 |
+
return f'\\[{formula_content}\\]'
|
40 |
+
replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
|
41 |
+
|
42 |
+
return replaced_text
|
43 |
+
|
44 |
+
|
45 |
+
class TableTree(Tree):
|
46 |
+
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
|
47 |
+
self.tag = tag
|
48 |
+
self.colspan = colspan
|
49 |
+
self.rowspan = rowspan
|
50 |
+
self.content = content
|
51 |
+
self.children = list(children)
|
52 |
+
|
53 |
+
def bracket(self):
|
54 |
+
"""Show tree using brackets notation"""
|
55 |
+
if self.tag == 'td':
|
56 |
+
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
|
57 |
+
(self.tag, self.colspan, self.rowspan, self.content)
|
58 |
+
else:
|
59 |
+
result = '"tag": %s' % self.tag
|
60 |
+
for child in self.children:
|
61 |
+
result += child.bracket()
|
62 |
+
return "{{{}}}".format(result)
|
63 |
+
|
64 |
+
|
65 |
+
class CustomConfig(Config):
|
66 |
+
@staticmethod
|
67 |
+
def maximum(*sequences):
|
68 |
+
"""Get maximum possible value
|
69 |
+
"""
|
70 |
+
return max(map(len, sequences))
|
71 |
+
|
72 |
+
def normalized_distance(self, *sequences):
|
73 |
+
"""Get distance from 0 to 1
|
74 |
+
"""
|
75 |
+
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
|
76 |
+
|
77 |
+
def rename(self, node1, node2):
|
78 |
+
"""Compares attributes of trees"""
|
79 |
+
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
|
80 |
+
return 1.
|
81 |
+
if node1.tag == 'td':
|
82 |
+
if node1.content or node2.content:
|
83 |
+
return self.normalized_distance(node1.content, node2.content)
|
84 |
+
return 0.
|
85 |
+
|
86 |
+
|
87 |
+
class TEDS(object):
|
88 |
+
''' Tree Edit Distance basead Similarity
|
89 |
+
'''
|
90 |
+
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
|
91 |
+
assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greather than 1'
|
92 |
+
self.structure_only = structure_only
|
93 |
+
self.n_jobs = n_jobs
|
94 |
+
self.ignore_nodes = ignore_nodes
|
95 |
+
self.__tokens__ = []
|
96 |
+
|
97 |
+
def tokenize(self, node):
|
98 |
+
''' Tokenizes table cells
|
99 |
+
'''
|
100 |
+
self.__tokens__.append('<%s>' % node.tag)
|
101 |
+
if node.text is not None:
|
102 |
+
self.__tokens__ += list(node.text)
|
103 |
+
for n in node.getchildren():
|
104 |
+
self.tokenize(n)
|
105 |
+
if node.tag != 'unk':
|
106 |
+
self.__tokens__.append('</%s>' % node.tag)
|
107 |
+
if node.tag != 'td' and node.tail is not None:
|
108 |
+
self.__tokens__ += list(node.tail)
|
109 |
+
|
110 |
+
def load_html_tree(self, node, parent=None):
|
111 |
+
''' Converts HTML tree to the format required by apted
|
112 |
+
'''
|
113 |
+
global __tokens__
|
114 |
+
if node.tag == 'td':
|
115 |
+
if self.structure_only:
|
116 |
+
cell = []
|
117 |
+
else:
|
118 |
+
self.__tokens__ = []
|
119 |
+
self.tokenize(node)
|
120 |
+
cell = self.__tokens__[1:-1].copy()
|
121 |
+
new_node = TableTree(node.tag,
|
122 |
+
int(node.attrib.get('colspan', '1')),
|
123 |
+
int(node.attrib.get('rowspan', '1')),
|
124 |
+
cell, *deque())
|
125 |
+
else:
|
126 |
+
new_node = TableTree(node.tag, None, None, None, *deque())
|
127 |
+
if parent is not None:
|
128 |
+
parent.children.append(new_node)
|
129 |
+
if node.tag != 'td':
|
130 |
+
for n in node.getchildren():
|
131 |
+
self.load_html_tree(n, new_node)
|
132 |
+
if parent is None:
|
133 |
+
return new_node
|
134 |
+
|
135 |
+
def evaluate(self, pred, true):
|
136 |
+
''' Computes TEDS score between the prediction and the ground truth of a
|
137 |
+
given sample
|
138 |
+
'''
|
139 |
+
if (not pred) or (not true):
|
140 |
+
return 0.0
|
141 |
+
pred.replace("<th>","<td>")
|
142 |
+
pred.replace("</th>","</td>")
|
143 |
+
pred = "<html>" + pred + "</html>"
|
144 |
+
true = "<html>" + true + "</html>"
|
145 |
+
parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
|
146 |
+
pred = html.fromstring(pred, parser=parser)
|
147 |
+
true = html.fromstring(true, parser=parser)
|
148 |
+
if pred.xpath('body/table') and true.xpath('body/table'):
|
149 |
+
pred = pred.xpath('body/table')[0]
|
150 |
+
true = true.xpath('body/table')[0]
|
151 |
+
if self.ignore_nodes:
|
152 |
+
etree.strip_tags(pred, *self.ignore_nodes)
|
153 |
+
etree.strip_tags(true, *self.ignore_nodes)
|
154 |
+
n_nodes_pred = len(pred.xpath(".//*"))
|
155 |
+
n_nodes_true = len(true.xpath(".//*"))
|
156 |
+
n_nodes = max(n_nodes_pred, n_nodes_true)
|
157 |
+
tree_pred = self.load_html_tree(pred)
|
158 |
+
tree_true = self.load_html_tree(true)
|
159 |
+
distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
|
160 |
+
return 1.0 - (float(distance) / n_nodes)
|
161 |
+
else:
|
162 |
+
return 0.0
|
163 |
+
|
164 |
+
def batch_evaluate(self, pred_json, true_json):
|
165 |
+
''' Computes TEDS score between the prediction and the ground truth of
|
166 |
+
a batch of samples
|
167 |
+
@params pred_json: {'FILENAME': 'HTML CODE', ...}
|
168 |
+
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
|
169 |
+
@output: {'FILENAME': 'TEDS SCORE', ...}
|
170 |
+
'''
|
171 |
+
samples = true_json.keys()
|
172 |
+
if self.n_jobs == 1:
|
173 |
+
scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
|
174 |
+
else:
|
175 |
+
inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
|
176 |
+
scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
|
177 |
+
total_score_simple = 0
|
178 |
+
num_simple = 0
|
179 |
+
total_score_complex = 0
|
180 |
+
num_complex = 0
|
181 |
+
total_score = 0
|
182 |
+
num_total = 0
|
183 |
+
for filename,score in zip(samples, scores):
|
184 |
+
print(filename)
|
185 |
+
print(score)
|
186 |
+
print('')
|
187 |
+
if true_json[filename]['type'] == 'simple':
|
188 |
+
total_score_simple += score
|
189 |
+
num_simple += 1
|
190 |
+
elif true_json[filename]['type'] == 'complex':
|
191 |
+
total_score_complex += score
|
192 |
+
num_complex += 1
|
193 |
+
else:
|
194 |
+
raise ValueError('Unknown type: %s' % true_json[filename]['type'])
|
195 |
+
total_score += score
|
196 |
+
num_total += 1
|
197 |
+
if num_simple > 0:
|
198 |
+
avg_score_simple = total_score_simple / num_simple
|
199 |
+
else:
|
200 |
+
avg_score_simple = 0
|
201 |
+
if num_complex > 0:
|
202 |
+
avg_score_complex = total_score_complex / num_complex
|
203 |
+
else:
|
204 |
+
avg_score_complex = 0
|
205 |
+
avg_score = total_score / num_total
|
206 |
+
print({'simple': (num_simple,avg_score_simple), 'complex': (num_complex,avg_score_complex), 'total': (num_total,avg_score)})
|
207 |
+
|
208 |
+
def simplify_html_table(html_table):
|
209 |
+
# 使用 BeautifulSoup 解析 HTML
|
210 |
+
soup = BeautifulSoup(html_table, 'html.parser')
|
211 |
+
|
212 |
+
# 找到 <table> 标签
|
213 |
+
table = soup.find('table')
|
214 |
+
if not table:
|
215 |
+
raise ValueError("输入的 HTML 不包含有效的 <table> 标签")
|
216 |
+
|
217 |
+
# 创建一个新的 <table> 标签
|
218 |
+
new_table = BeautifulSoup('<table></table>', 'html.parser').table
|
219 |
+
|
220 |
+
# 提取所有行(包括 <thead> 和 <tbody> 中的行)
|
221 |
+
rows = table.find_all(['tr'], recursive=True)
|
222 |
+
|
223 |
+
for row in rows:
|
224 |
+
# 创建新的 <tr> 标签
|
225 |
+
new_row = soup.new_tag('tr')
|
226 |
+
|
227 |
+
# 处理每一行中的单元格
|
228 |
+
cells = row.find_all(['th', 'td'])
|
229 |
+
for cell in cells:
|
230 |
+
# 将 <th> 替换为 <td>
|
231 |
+
new_cell = soup.new_tag('td')
|
232 |
+
if cell.has_attr('rowspan'):
|
233 |
+
new_cell['rowspan'] = cell['rowspan']
|
234 |
+
if cell.has_attr('colspan'):
|
235 |
+
new_cell['colspan'] = cell['colspan']
|
236 |
+
new_cell.string = cell.get_text(strip=True) # 保留单元格内容
|
237 |
+
new_row.append(new_cell)
|
238 |
+
|
239 |
+
# 将新行添加到新表格中
|
240 |
+
new_table.append(new_row)
|
241 |
+
|
242 |
+
# 返回简化后的表格 HTML
|
243 |
+
return str(new_table)
|
244 |
+
|
245 |
+
|
246 |
+
def main():
|
247 |
+
parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
|
248 |
+
parser.add_argument(
|
249 |
+
"workspace",
|
250 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
251 |
+
)
|
252 |
+
parser.add_argument(
|
253 |
+
"--gt_file",
|
254 |
+
help="Ground truth file",
|
255 |
+
)
|
256 |
+
parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
|
257 |
+
args = parser.parse_args()
|
258 |
+
|
259 |
+
pred_data = {}
|
260 |
+
for file in os.listdir(args.workspace):
|
261 |
+
file_path = os.path.join(args.workspace, file)
|
262 |
+
pdf_name = file.split('.')[0] + ".png"
|
263 |
+
with open(file_path, "r") as f:
|
264 |
+
document_text = f.read()
|
265 |
+
document_text = replace_single_dollar(replace_double_dollar(document_text))
|
266 |
+
markdown_text_list = document_text.split("\n\n")
|
267 |
+
new_markdown_text_list = []
|
268 |
+
for text in markdown_text_list:
|
269 |
+
text = text.strip()
|
270 |
+
if (text.startswith("<watermark>") and text.endswith("</watermark>")) or (text.startswith("<img>") and text.endswith("</img>")) or (text.startswith("<page_number>") and text.endswith("</page_number>")) or (text.startswith("<signature>") and text.endswith("</signature>")):
|
271 |
+
continue
|
272 |
+
else:
|
273 |
+
html_text = str(markdown2.markdown(text,extras=["tables"]))
|
274 |
+
html_text = html_text.strip()
|
275 |
+
if html_text.startswith("<table>") and html_text.endswith("</table>"):
|
276 |
+
html_table = simplify_html_table(html_text)
|
277 |
+
new_markdown_text_list.append(html_table)
|
278 |
+
else:
|
279 |
+
text = turn_header_to_h1(text)
|
280 |
+
new_markdown_text_list.append(text)
|
281 |
+
|
282 |
+
pred_data[os.path.basename(pdf_name)] = "\n\n".join(new_markdown_text_list)
|
283 |
+
|
284 |
+
|
285 |
+
gt_data = {}
|
286 |
+
with open(args.gt_file, "r") as f:
|
287 |
+
for line in f:
|
288 |
+
data = json.loads(line)
|
289 |
+
gt_data[data['image_name']] = {'html':data['gt_table'], 'type':data['type']}
|
290 |
+
|
291 |
+
teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
|
292 |
+
teds.batch_evaluate(pred_data, gt_data)
|
293 |
+
|
294 |
+
if __name__ == "__main__":
|
295 |
+
main()
|
eval/eval_table_to_html_olmocr.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import distance
|
5 |
+
import markdown2
|
6 |
+
from apted import APTED, Config
|
7 |
+
from apted.helpers import Tree
|
8 |
+
from lxml import etree, html
|
9 |
+
from collections import deque
|
10 |
+
from tqdm import tqdm
|
11 |
+
from eval.parallel import parallel_process
|
12 |
+
|
13 |
+
|
14 |
+
class TableTree(Tree):
|
15 |
+
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
|
16 |
+
self.tag = tag
|
17 |
+
self.colspan = colspan
|
18 |
+
self.rowspan = rowspan
|
19 |
+
self.content = content
|
20 |
+
self.children = list(children)
|
21 |
+
|
22 |
+
def bracket(self):
|
23 |
+
"""Show tree using brackets notation"""
|
24 |
+
if self.tag == 'td':
|
25 |
+
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
|
26 |
+
(self.tag, self.colspan, self.rowspan, self.content)
|
27 |
+
else:
|
28 |
+
result = '"tag": %s' % self.tag
|
29 |
+
for child in self.children:
|
30 |
+
result += child.bracket()
|
31 |
+
return "{{{}}}".format(result)
|
32 |
+
|
33 |
+
|
34 |
+
class CustomConfig(Config):
|
35 |
+
@staticmethod
|
36 |
+
def maximum(*sequences):
|
37 |
+
"""Get maximum possible value
|
38 |
+
"""
|
39 |
+
return max(map(len, sequences))
|
40 |
+
|
41 |
+
def normalized_distance(self, *sequences):
|
42 |
+
"""Get distance from 0 to 1
|
43 |
+
"""
|
44 |
+
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
|
45 |
+
|
46 |
+
def rename(self, node1, node2):
|
47 |
+
"""Compares attributes of trees"""
|
48 |
+
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
|
49 |
+
return 1.
|
50 |
+
if node1.tag == 'td':
|
51 |
+
if node1.content or node2.content:
|
52 |
+
return self.normalized_distance(node1.content, node2.content)
|
53 |
+
return 0.
|
54 |
+
|
55 |
+
|
56 |
+
class TEDS(object):
|
57 |
+
''' Tree Edit Distance basead Similarity
|
58 |
+
'''
|
59 |
+
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
|
60 |
+
assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greather than 1'
|
61 |
+
self.structure_only = structure_only
|
62 |
+
self.n_jobs = n_jobs
|
63 |
+
self.ignore_nodes = ignore_nodes
|
64 |
+
self.__tokens__ = []
|
65 |
+
|
66 |
+
def tokenize(self, node):
|
67 |
+
''' Tokenizes table cells
|
68 |
+
'''
|
69 |
+
self.__tokens__.append('<%s>' % node.tag)
|
70 |
+
if node.text is not None:
|
71 |
+
self.__tokens__ += list(node.text)
|
72 |
+
for n in node.getchildren():
|
73 |
+
self.tokenize(n)
|
74 |
+
if node.tag != 'unk':
|
75 |
+
self.__tokens__.append('</%s>' % node.tag)
|
76 |
+
if node.tag != 'td' and node.tail is not None:
|
77 |
+
self.__tokens__ += list(node.tail)
|
78 |
+
|
79 |
+
def load_html_tree(self, node, parent=None):
|
80 |
+
''' Converts HTML tree to the format required by apted
|
81 |
+
'''
|
82 |
+
global __tokens__
|
83 |
+
if node.tag == 'td':
|
84 |
+
if self.structure_only:
|
85 |
+
cell = []
|
86 |
+
else:
|
87 |
+
self.__tokens__ = []
|
88 |
+
self.tokenize(node)
|
89 |
+
cell = self.__tokens__[1:-1].copy()
|
90 |
+
new_node = TableTree(node.tag,
|
91 |
+
int(node.attrib.get('colspan', '1')),
|
92 |
+
int(node.attrib.get('rowspan', '1')),
|
93 |
+
cell, *deque())
|
94 |
+
else:
|
95 |
+
new_node = TableTree(node.tag, None, None, None, *deque())
|
96 |
+
if parent is not None:
|
97 |
+
parent.children.append(new_node)
|
98 |
+
if node.tag != 'td':
|
99 |
+
for n in node.getchildren():
|
100 |
+
self.load_html_tree(n, new_node)
|
101 |
+
if parent is None:
|
102 |
+
return new_node
|
103 |
+
|
104 |
+
def evaluate(self, pred, true):
|
105 |
+
''' Computes TEDS score between the prediction and the ground truth of a
|
106 |
+
given sample
|
107 |
+
'''
|
108 |
+
if (not pred) or (not true):
|
109 |
+
return 0.0
|
110 |
+
pred.replace("<th>","<td>")
|
111 |
+
pred.replace("</th>","</td>")
|
112 |
+
pred = "<html>" + pred + "</html>"
|
113 |
+
true = "<html>" + true + "</html>"
|
114 |
+
parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
|
115 |
+
pred = html.fromstring(pred, parser=parser)
|
116 |
+
true = html.fromstring(true, parser=parser)
|
117 |
+
if pred.xpath('body/table') and true.xpath('body/table'):
|
118 |
+
pred = pred.xpath('body/table')[0]
|
119 |
+
true = true.xpath('body/table')[0]
|
120 |
+
if self.ignore_nodes:
|
121 |
+
etree.strip_tags(pred, *self.ignore_nodes)
|
122 |
+
etree.strip_tags(true, *self.ignore_nodes)
|
123 |
+
n_nodes_pred = len(pred.xpath(".//*"))
|
124 |
+
n_nodes_true = len(true.xpath(".//*"))
|
125 |
+
n_nodes = max(n_nodes_pred, n_nodes_true)
|
126 |
+
tree_pred = self.load_html_tree(pred)
|
127 |
+
tree_true = self.load_html_tree(true)
|
128 |
+
distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
|
129 |
+
return 1.0 - (float(distance) / n_nodes)
|
130 |
+
else:
|
131 |
+
return 0.0
|
132 |
+
|
133 |
+
def batch_evaluate(self, pred_json, true_json):
|
134 |
+
''' Computes TEDS score between the prediction and the ground truth of
|
135 |
+
a batch of samples
|
136 |
+
@params pred_json: {'FILENAME': 'HTML CODE', ...}
|
137 |
+
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
|
138 |
+
@output: {'FILENAME': 'TEDS SCORE', ...}
|
139 |
+
'''
|
140 |
+
samples = true_json.keys()
|
141 |
+
if self.n_jobs == 1:
|
142 |
+
scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
|
143 |
+
else:
|
144 |
+
inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
|
145 |
+
scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
|
146 |
+
total_score_simple = 0
|
147 |
+
num_simple = 0
|
148 |
+
total_score_complex = 0
|
149 |
+
num_complex = 0
|
150 |
+
total_score = 0
|
151 |
+
num_total = 0
|
152 |
+
for filename,score in zip(samples, scores):
|
153 |
+
print(filename)
|
154 |
+
print(score)
|
155 |
+
print('')
|
156 |
+
if true_json[filename]['type'] == 'simple':
|
157 |
+
total_score_simple += score
|
158 |
+
num_simple += 1
|
159 |
+
elif true_json[filename]['type'] == 'complex':
|
160 |
+
total_score_complex += score
|
161 |
+
num_complex += 1
|
162 |
+
else:
|
163 |
+
raise ValueError('Unknown type: %s' % true_json[filename]['type'])
|
164 |
+
total_score += score
|
165 |
+
num_total += 1
|
166 |
+
if num_simple > 0:
|
167 |
+
avg_score_simple = total_score_simple / num_simple
|
168 |
+
else:
|
169 |
+
avg_score_simple = 0
|
170 |
+
if num_complex > 0:
|
171 |
+
avg_score_complex = total_score_complex / num_complex
|
172 |
+
else:
|
173 |
+
avg_score_complex = 0
|
174 |
+
avg_score = total_score / num_total
|
175 |
+
print({'simple': (num_simple,avg_score_simple), 'complex': (num_complex,avg_score_complex), 'total': (num_total,avg_score)})
|
176 |
+
|
177 |
+
def main():
|
178 |
+
parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
|
179 |
+
parser.add_argument(
|
180 |
+
"workspace",
|
181 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
182 |
+
)
|
183 |
+
parser.add_argument(
|
184 |
+
"--gt_file",
|
185 |
+
help="Ground truth file",
|
186 |
+
)
|
187 |
+
parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
|
188 |
+
args = parser.parse_args()
|
189 |
+
|
190 |
+
pred_data = {}
|
191 |
+
root_dir = os.path.join(args.workspace, "results")
|
192 |
+
for jsonl_file in os.listdir(root_dir):
|
193 |
+
if jsonl_file.endswith(".jsonl"):
|
194 |
+
with open(os.path.join(root_dir, jsonl_file), "r") as f:
|
195 |
+
for line in f:
|
196 |
+
data = json.loads(line)
|
197 |
+
pdf_path = os.path.basename(data['metadata']['Source-File'])
|
198 |
+
document_text = data['text']
|
199 |
+
pred_data[pdf_path] = str(markdown2.markdown(document_text,extras=["tables"]))
|
200 |
+
|
201 |
+
|
202 |
+
gt_data = {}
|
203 |
+
with open(args.gt_file, "r") as f:
|
204 |
+
for line in f:
|
205 |
+
data = json.loads(line)
|
206 |
+
gt_data[data['image_name']] = {'html':data['gt_table'], 'type':data['type']}
|
207 |
+
|
208 |
+
teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
|
209 |
+
teds.batch_evaluate(pred_data, gt_data)
|
210 |
+
|
211 |
+
if __name__ == "__main__":
|
212 |
+
main()
|
eval/gen_element_merge_detect_data.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
def main():
|
6 |
+
parser = argparse.ArgumentParser(description="Evaluate element_merge_detect task")
|
7 |
+
parser.add_argument(
|
8 |
+
"workspace",
|
9 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
10 |
+
)
|
11 |
+
args = parser.parse_args()
|
12 |
+
|
13 |
+
json_dir = os.path.join(args.workspace, 'jsons')
|
14 |
+
if not os.path.exists(json_dir):
|
15 |
+
os.makedirs(json_dir)
|
16 |
+
|
17 |
+
jsonl_file = os.path.join(args.workspace, "data.jsonl")
|
18 |
+
with open(jsonl_file, "r") as f:
|
19 |
+
for line in f:
|
20 |
+
data = json.loads(line)
|
21 |
+
pdf_name_1 = data['pdf_name_1'].split(".")[0]
|
22 |
+
pdf_name_2 = data['pdf_name_2'].split(".")[0]
|
23 |
+
|
24 |
+
pdf_name,page_1 = pdf_name_1.split('_')
|
25 |
+
pdf_name,page_2 = pdf_name_2.split('_')
|
26 |
+
|
27 |
+
json_name = os.path.join(json_dir, pdf_name + '_' + page_1 + '_' + page_2 + '.json')
|
28 |
+
data = {
|
29 |
+
"page_1": "\n\n".join(data['md_elem_list_1']),
|
30 |
+
"page_2": "\n\n".join(data['md_elem_list_2']),
|
31 |
+
}
|
32 |
+
with open(json_name, 'w') as f:
|
33 |
+
json.dump(data, f)
|
34 |
+
|
35 |
+
if __name__ == "__main__":
|
36 |
+
main()
|
eval/gen_html_table_merge_data.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
def main():
|
6 |
+
parser = argparse.ArgumentParser(description="Evaluate element_merge_detect task")
|
7 |
+
parser.add_argument(
|
8 |
+
"workspace",
|
9 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
10 |
+
)
|
11 |
+
args = parser.parse_args()
|
12 |
+
|
13 |
+
json_dir = os.path.join(args.workspace, 'jsons')
|
14 |
+
if not os.path.exists(json_dir):
|
15 |
+
os.makedirs(json_dir)
|
16 |
+
|
17 |
+
jsonl_file = os.path.join(args.workspace, 'data.jsonl')
|
18 |
+
with open(jsonl_file, "r") as f:
|
19 |
+
for line in f:
|
20 |
+
data = json.loads(line)
|
21 |
+
json_name = data['image_name'].split('.')[0] + '.json'
|
22 |
+
|
23 |
+
json_path = os.path.join(json_dir, json_name)
|
24 |
+
data = {
|
25 |
+
"table_1": data['table_fragment_1'],
|
26 |
+
"table_2": data['table_fragment_2'],
|
27 |
+
}
|
28 |
+
with open(json_path, 'w') as f:
|
29 |
+
json.dump(data, f)
|
30 |
+
|
31 |
+
if __name__ == "__main__":
|
32 |
+
main()
|
eval/parallel.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
3 |
+
|
4 |
+
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
|
5 |
+
"""
|
6 |
+
A parallel version of the map function with a progress bar.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
array (array-like): An array to iterate over.
|
10 |
+
function (function): A python function to apply to the elements of array
|
11 |
+
n_jobs (int, default=16): The number of cores to use
|
12 |
+
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
|
13 |
+
keyword arguments to function
|
14 |
+
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
|
15 |
+
Useful for catching bugs
|
16 |
+
Returns:
|
17 |
+
[function(array[0]), function(array[1]), ...]
|
18 |
+
"""
|
19 |
+
# We run the first few iterations serially to catch bugs
|
20 |
+
if front_num > 0:
|
21 |
+
front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
|
22 |
+
else:
|
23 |
+
front = []
|
24 |
+
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
|
25 |
+
if n_jobs == 1:
|
26 |
+
return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
|
27 |
+
# Assemble the workers
|
28 |
+
with ProcessPoolExecutor(max_workers=n_jobs) as pool:
|
29 |
+
# Pass the elements of array into function
|
30 |
+
if use_kwargs:
|
31 |
+
futures = [pool.submit(function, **a) for a in array[front_num:]]
|
32 |
+
else:
|
33 |
+
futures = [pool.submit(function, a) for a in array[front_num:]]
|
34 |
+
kwargs = {
|
35 |
+
'total': len(futures),
|
36 |
+
'unit': 'it',
|
37 |
+
'unit_scale': True,
|
38 |
+
'leave': True
|
39 |
+
}
|
40 |
+
# Print out the progress as tasks complete
|
41 |
+
for f in tqdm(as_completed(futures), **kwargs):
|
42 |
+
pass
|
43 |
+
out = []
|
44 |
+
# Get the results from the futures.
|
45 |
+
for i, future in tqdm(enumerate(futures)):
|
46 |
+
try:
|
47 |
+
out.append(future.result())
|
48 |
+
except Exception as e:
|
49 |
+
out.append(e)
|
50 |
+
return front + out
|
ocrflux/check.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.util
|
2 |
+
import logging
|
3 |
+
import subprocess
|
4 |
+
import sys
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
|
9 |
+
def check_poppler_version():
|
10 |
+
try:
|
11 |
+
result = subprocess.run(["pdftoppm", "-h"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
12 |
+
if result.returncode == 0 and result.stderr.startswith("pdftoppm"):
|
13 |
+
logger.info("pdftoppm is installed and working.")
|
14 |
+
else:
|
15 |
+
logger.error("pdftoppm is installed but returned an error.")
|
16 |
+
sys.exit(1)
|
17 |
+
except FileNotFoundError:
|
18 |
+
logger.error("pdftoppm is not installed.")
|
19 |
+
sys.exit(1)
|
20 |
+
|
21 |
+
def check_vllm_version():
|
22 |
+
if importlib.util.find_spec("vllm") is None:
|
23 |
+
logger.error("VLLM needs to be installed with a separate command in order to find all dependencies properly.")
|
24 |
+
sys.exit(1)
|
25 |
+
|
26 |
+
|
27 |
+
def check_torch_gpu_available(min_gpu_memory: int = 20 * 1024**3):
|
28 |
+
try:
|
29 |
+
import torch
|
30 |
+
except:
|
31 |
+
logger.error("Pytorch must be installed, visit https://pytorch.org/ for installation instructions")
|
32 |
+
raise
|
33 |
+
|
34 |
+
try:
|
35 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory
|
36 |
+
assert gpu_memory >= min_gpu_memory
|
37 |
+
except:
|
38 |
+
logger.error(f"Torch was not able to find a GPU with at least {min_gpu_memory // (1024 ** 3)} GB of RAM.")
|
39 |
+
raise
|
40 |
+
|
41 |
+
|
42 |
+
if __name__ == "__main__":
|
43 |
+
check_poppler_version()
|
44 |
+
check_vllm_version()
|
ocrflux/image_utils.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import io
|
4 |
+
from typing import List, Union
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
def get_page_image(pdf_path, page_number, target_longest_image_dim=None, image_rotation=0):
|
9 |
+
if pdf_path.lower().endswith(".pdf"):
|
10 |
+
# Convert PDF page to PNG using pdftoppm
|
11 |
+
pdftoppm_result = subprocess.run(
|
12 |
+
[
|
13 |
+
"pdftoppm",
|
14 |
+
"-png",
|
15 |
+
"-f",
|
16 |
+
str(page_number),
|
17 |
+
"-l",
|
18 |
+
str(page_number),
|
19 |
+
"-r",
|
20 |
+
"72", # 72 pixels per point is the conversion factor
|
21 |
+
pdf_path,
|
22 |
+
],
|
23 |
+
timeout=120,
|
24 |
+
stdout=subprocess.PIPE,
|
25 |
+
stderr=subprocess.PIPE,
|
26 |
+
)
|
27 |
+
assert pdftoppm_result.returncode == 0, pdftoppm_result.stderr
|
28 |
+
image = Image.open(io.BytesIO(pdftoppm_result.stdout))
|
29 |
+
else:
|
30 |
+
image = Image.open(pdf_path)
|
31 |
+
if image_rotation != 0:
|
32 |
+
image = image.rotate(-image_rotation, expand=True)
|
33 |
+
if target_longest_image_dim is not None:
|
34 |
+
width, height = image.size
|
35 |
+
if width > height:
|
36 |
+
new_width = target_longest_image_dim
|
37 |
+
new_height = int(height * (target_longest_image_dim / width))
|
38 |
+
else:
|
39 |
+
new_height = target_longest_image_dim
|
40 |
+
new_width = int(width * (target_longest_image_dim / height))
|
41 |
+
image = image.resize((new_width, new_height))
|
42 |
+
return image
|
43 |
+
|
44 |
+
|
45 |
+
def is_image(file_path):
|
46 |
+
try:
|
47 |
+
Image.open(file_path)
|
48 |
+
return True
|
49 |
+
except:
|
50 |
+
return False
|
ocrflux/inference.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import copy
|
3 |
+
from PIL import Image
|
4 |
+
from pypdf import PdfReader
|
5 |
+
from vllm import LLM, SamplingParams
|
6 |
+
from ocrflux.image_utils import get_page_image
|
7 |
+
from ocrflux.table_format import table_matrix2html
|
8 |
+
from ocrflux.prompts import PageResponse, build_page_to_markdown_prompt, build_element_merge_detect_prompt, build_html_table_merge_prompt
|
9 |
+
|
10 |
+
def build_qwen2_5_vl_prompt(question):
|
11 |
+
return (
|
12 |
+
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
13 |
+
f"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
|
14 |
+
f"{question}<|im_end|>\n"
|
15 |
+
"<|im_start|>assistant\n"
|
16 |
+
)
|
17 |
+
|
18 |
+
def build_page_to_markdown_query(file_path: str, page_number: int, target_longest_image_dim: int = 1024, image_rotation: int = 0) -> dict:
|
19 |
+
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
|
20 |
+
image = get_page_image(file_path, page_number, target_longest_image_dim=target_longest_image_dim, image_rotation=image_rotation)
|
21 |
+
question = build_page_to_markdown_prompt()
|
22 |
+
prompt = build_qwen2_5_vl_prompt(question)
|
23 |
+
query = {
|
24 |
+
"prompt": prompt,
|
25 |
+
"multi_modal_data": {"image": image},
|
26 |
+
}
|
27 |
+
return query
|
28 |
+
|
29 |
+
def build_element_merge_detect_query(text_list_1,text_list_2) -> dict:
|
30 |
+
image = Image.new('RGB', (28, 28), color='black')
|
31 |
+
question = build_element_merge_detect_prompt(text_list_1,text_list_2)
|
32 |
+
prompt = build_qwen2_5_vl_prompt(question)
|
33 |
+
query = {
|
34 |
+
"prompt": prompt,
|
35 |
+
"multi_modal_data": {"image": image},
|
36 |
+
}
|
37 |
+
return query
|
38 |
+
|
39 |
+
def build_html_table_merge_query(text_1,text_2) -> dict:
|
40 |
+
image = Image.new('RGB', (28, 28), color='black')
|
41 |
+
question = build_html_table_merge_prompt(text_1,text_2)
|
42 |
+
prompt = build_qwen2_5_vl_prompt(question)
|
43 |
+
query = {
|
44 |
+
"prompt": prompt,
|
45 |
+
"multi_modal_data": {"image": image},
|
46 |
+
}
|
47 |
+
return query
|
48 |
+
|
49 |
+
def bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result):
|
50 |
+
page_to_markdown_keys = list(page_to_markdown_result.keys())
|
51 |
+
element_merge_detect_keys = list(element_merge_detect_result.keys())
|
52 |
+
html_table_merge_keys = list(html_table_merge_result.keys())
|
53 |
+
|
54 |
+
for page_1,page_2,elem_idx_1,elem_idx_2 in sorted(html_table_merge_keys,key=lambda x: -x[0]):
|
55 |
+
page_to_markdown_result[page_1][elem_idx_1] = html_table_merge_result[(page_1,page_2,elem_idx_1,elem_idx_2)]
|
56 |
+
page_to_markdown_result[page_2][elem_idx_2] = ''
|
57 |
+
|
58 |
+
for page_1,page_2 in sorted(element_merge_detect_keys,key=lambda x: -x[0]):
|
59 |
+
for elem_idx_1,elem_idx_2 in element_merge_detect_result[(page_1,page_2)]:
|
60 |
+
if len(page_to_markdown_result[page_1][elem_idx_1]) == 0 or page_to_markdown_result[page_1][elem_idx_1][-1] == '-' or ('\u4e00' <= page_to_markdown_result[page_1][elem_idx_1][-1] <= '\u9fff'):
|
61 |
+
page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + '' + page_to_markdown_result[page_2][elem_idx_2]
|
62 |
+
else:
|
63 |
+
page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + ' ' + page_to_markdown_result[page_2][elem_idx_2]
|
64 |
+
page_to_markdown_result[page_2][elem_idx_2] = ''
|
65 |
+
|
66 |
+
document_text_list = []
|
67 |
+
for page in page_to_markdown_keys:
|
68 |
+
page_text_list = [s for s in page_to_markdown_result[page] if s]
|
69 |
+
document_text_list += page_text_list
|
70 |
+
return "\n\n".join(document_text_list)
|
71 |
+
|
72 |
+
def parse(llm,file_path,skip_cross_page_merge=False,max_page_retries=0):
|
73 |
+
sampling_params = SamplingParams(temperature=0.0,max_tokens=8192)
|
74 |
+
if file_path.lower().endswith(".pdf"):
|
75 |
+
try:
|
76 |
+
reader = PdfReader(file_path)
|
77 |
+
num_pages = reader.get_num_pages()
|
78 |
+
except:
|
79 |
+
return None
|
80 |
+
else:
|
81 |
+
num_pages = 1
|
82 |
+
|
83 |
+
try:
|
84 |
+
# Stage 1: Page to Markdown
|
85 |
+
page_to_markdown_query_list = [build_page_to_markdown_query(file_path,page_num) for page_num in range(1, num_pages + 1)]
|
86 |
+
responses = llm.generate(page_to_markdown_query_list, sampling_params=sampling_params)
|
87 |
+
results = [response.outputs[0].text for response in responses]
|
88 |
+
page_to_markdown_result = {}
|
89 |
+
retry_list = []
|
90 |
+
for i,result in enumerate(results):
|
91 |
+
try:
|
92 |
+
json_data = json.loads(result)
|
93 |
+
page_response = PageResponse(**json_data)
|
94 |
+
natural_text = page_response.natural_text
|
95 |
+
markdown_element_list = []
|
96 |
+
for text in natural_text.split('\n\n'):
|
97 |
+
if text.startswith("<Image>") and text.endswith("</Image>"):
|
98 |
+
pass
|
99 |
+
elif text.startswith("<table>") and text.endswith("</table>"):
|
100 |
+
try:
|
101 |
+
new_text = table_matrix2html(text)
|
102 |
+
except:
|
103 |
+
new_text = text.replace("<t>","").replace("<l>","").replace("<lt>","")
|
104 |
+
markdown_element_list.append(new_text)
|
105 |
+
else:
|
106 |
+
markdown_element_list.append(text)
|
107 |
+
page_to_markdown_result[i+1] = markdown_element_list
|
108 |
+
except:
|
109 |
+
retry_list.append(i)
|
110 |
+
|
111 |
+
attempt = 0
|
112 |
+
while len(retry_list) > 0 and attempt < max_page_retries:
|
113 |
+
retry_page_to_markdown_query_list = [build_page_to_markdown_query(file_path,page_num) for page_num in retry_list]
|
114 |
+
retry_sampling_params = SamplingParams(temperature=0.1*attempt, max_tokens=8192)
|
115 |
+
responses = llm.generate(retry_page_to_markdown_query_list, sampling_params=retry_sampling_params)
|
116 |
+
results = [response.outputs[0].text for response in responses]
|
117 |
+
next_retry_list = []
|
118 |
+
for i,result in zip(retry_list,results):
|
119 |
+
try:
|
120 |
+
json_data = json.loads(result)
|
121 |
+
page_response = PageResponse(**json_data)
|
122 |
+
natural_text = page_response.natural_text
|
123 |
+
markdown_element_list = []
|
124 |
+
for text in natural_text.split('\n\n'):
|
125 |
+
if text.startswith("<Image>") and text.endswith("</Image>"):
|
126 |
+
pass
|
127 |
+
elif text.startswith("<table>") and text.endswith("</table>"):
|
128 |
+
try:
|
129 |
+
new_text = table_matrix2html(text)
|
130 |
+
except:
|
131 |
+
new_text = text.replace("<t>","").replace("<l>","").replace("<lt>","")
|
132 |
+
markdown_element_list.append(new_text)
|
133 |
+
else:
|
134 |
+
markdown_element_list.append(text)
|
135 |
+
page_to_markdown_result[i+1] = markdown_element_list
|
136 |
+
except:
|
137 |
+
next_retry_list.append(i)
|
138 |
+
retry_list = next_retry_list
|
139 |
+
attempt += 1
|
140 |
+
|
141 |
+
page_texts = {}
|
142 |
+
fallback_pages = []
|
143 |
+
for page_number in range(1, num_pages+1):
|
144 |
+
if page_number not in page_to_markdown_result.keys():
|
145 |
+
fallback_pages.append(page_number-1)
|
146 |
+
else:
|
147 |
+
page_texts[str(page_number-1)] = "\n\n".join(page_to_markdown_result[page_number])
|
148 |
+
|
149 |
+
if skip_cross_page_merge:
|
150 |
+
document_text_list = []
|
151 |
+
for i in range(num_pages):
|
152 |
+
if i not in fallback_pages:
|
153 |
+
document_text_list.append(page_texts[str(i)])
|
154 |
+
document_text = "\n\n".join(document_text_list)
|
155 |
+
return {
|
156 |
+
"orig_path": file_path,
|
157 |
+
"num_pages": num_pages,
|
158 |
+
"document_text": document_text,
|
159 |
+
"page_texts": page_texts,
|
160 |
+
"fallback_pages": fallback_pages,
|
161 |
+
}
|
162 |
+
|
163 |
+
# Stage 2: Element Merge Detect
|
164 |
+
element_merge_detect_keys = []
|
165 |
+
element_merge_detect_query_list = []
|
166 |
+
for page_num in range(1,num_pages):
|
167 |
+
if page_num in page_to_markdown_result.keys() and page_num+1 in page_to_markdown_result.keys():
|
168 |
+
element_merge_detect_query_list.append(build_element_merge_detect_query(page_to_markdown_result[page_num],page_to_markdown_result[page_num+1]))
|
169 |
+
element_merge_detect_keys.append((page_num,page_num+1))
|
170 |
+
responses = llm.generate(element_merge_detect_query_list, sampling_params=sampling_params)
|
171 |
+
results = [response.outputs[0].text for response in responses]
|
172 |
+
element_merge_detect_result = {}
|
173 |
+
for key,result in zip(element_merge_detect_keys,results):
|
174 |
+
try:
|
175 |
+
element_merge_detect_result[key] = eval(result)
|
176 |
+
except:
|
177 |
+
pass
|
178 |
+
|
179 |
+
# Stage 3: HTML Table Merge
|
180 |
+
html_table_merge_keys = []
|
181 |
+
for key,result in element_merge_detect_result.items():
|
182 |
+
page_1,page_2 = key
|
183 |
+
for elem_idx_1,elem_idx_2 in result:
|
184 |
+
text_1 = page_to_markdown_result[page_1][elem_idx_1]
|
185 |
+
text_2 = page_to_markdown_result[page_2][elem_idx_2]
|
186 |
+
if text_1.startswith("<table>") and text_1.endswith("</table>") and text_2.startswith("<table>") and text_2.endswith("</table>"):
|
187 |
+
html_table_merge_keys.append((page_1,page_2,elem_idx_1,elem_idx_2))
|
188 |
+
|
189 |
+
html_table_merge_keys = sorted(html_table_merge_keys,key=lambda x: -x[0])
|
190 |
+
|
191 |
+
html_table_merge_result = {}
|
192 |
+
page_to_markdown_result_tmp = copy.deepcopy(page_to_markdown_result)
|
193 |
+
i = 0
|
194 |
+
while i < len(html_table_merge_keys):
|
195 |
+
tmp = set()
|
196 |
+
keys = []
|
197 |
+
while i < len(html_table_merge_keys):
|
198 |
+
page_1,page_2,elem_idx_1,elem_idx_2 = html_table_merge_keys[i]
|
199 |
+
if (page_2,elem_idx_2) in tmp:
|
200 |
+
break
|
201 |
+
tmp.add((page_1,elem_idx_1))
|
202 |
+
keys.append((page_1,page_2,elem_idx_1,elem_idx_2))
|
203 |
+
i += 1
|
204 |
+
|
205 |
+
html_table_merge_query_list = [build_html_table_merge_query(page_to_markdown_result_tmp[page_1][elem_idx_1],page_to_markdown_result_tmp[page_2][elem_idx_2]) for page_1,page_2,elem_idx_1,elem_idx_2 in keys]
|
206 |
+
responses = llm.generate(html_table_merge_query_list, sampling_params=sampling_params)
|
207 |
+
results = [response.outputs[0].text for response in responses]
|
208 |
+
for key,result in zip(keys,results):
|
209 |
+
if result.startswith("<table>") and result.endswith("</table>"):
|
210 |
+
html_table_merge_result[key] = result
|
211 |
+
page_to_markdown_result_tmp[page_1][elem_idx_1] = result
|
212 |
+
|
213 |
+
document_text = bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result)
|
214 |
+
return {
|
215 |
+
"orig_path": file_path,
|
216 |
+
"num_pages": num_pages,
|
217 |
+
"document_text": document_text,
|
218 |
+
"page_texts": page_texts,
|
219 |
+
"fallback_pages": fallback_pages,
|
220 |
+
}
|
221 |
+
except:
|
222 |
+
return None
|
223 |
+
|
224 |
+
|
225 |
+
if __name__ == '__main__':
|
226 |
+
file_path = 'test.pdf'
|
227 |
+
llm = LLM(model="ChatDOC/OCRFlux-3B",gpu_memory_utilization=0.8,max_model_len=8192)
|
228 |
+
result = parse(llm,file_path,max_page_retries=4)
|
229 |
+
if result != None:
|
230 |
+
document_markdown = result['document_text']
|
231 |
+
print(document_markdown)
|
232 |
+
with open('test.md','w') as f:
|
233 |
+
f.write(document_markdown)
|
234 |
+
else:
|
235 |
+
print("Parse failed")
|
236 |
+
|
237 |
+
|
ocrflux/jsonl_to_markdown.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
def main():
|
5 |
+
parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
|
6 |
+
parser.add_argument(
|
7 |
+
"workspace",
|
8 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
9 |
+
)
|
10 |
+
parser.add_argument("--show_page_result", action="store_true", help="Whether to show the markdown of each page")
|
11 |
+
args = parser.parse_args()
|
12 |
+
|
13 |
+
src_dir = os.path.join(args.workspace, "results")
|
14 |
+
tgt_dir = os.path.join(args.workspace, "markdowns")
|
15 |
+
if not os.path.exists(tgt_dir):
|
16 |
+
os.makedirs(tgt_dir)
|
17 |
+
for jsonl_file in os.listdir(src_dir):
|
18 |
+
if jsonl_file.endswith(".jsonl"):
|
19 |
+
with open(os.path.join(src_dir, jsonl_file), "r") as f:
|
20 |
+
for line in f:
|
21 |
+
data = json.loads(line)
|
22 |
+
markdown_text = data['document_text']
|
23 |
+
file_name = os.path.basename(data['orig_path']).split(".")[0]
|
24 |
+
file_dir = os.path.join(tgt_dir, file_name)
|
25 |
+
if not os.path.exists(file_dir):
|
26 |
+
os.makedirs(file_dir)
|
27 |
+
with open(os.path.join(file_dir, file_name+".md"), "w") as f:
|
28 |
+
f.write(markdown_text)
|
29 |
+
if args.show_page_result:
|
30 |
+
page_texts = data["page_texts"]
|
31 |
+
for page_num in page_texts.keys():
|
32 |
+
page_text = page_texts[page_num]
|
33 |
+
with open(os.path.join(file_dir, file_name+"_"+str(page_num)+".md"), "w") as f:
|
34 |
+
f.write(page_text)
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
main()
|
ocrflux/metrics.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import time
|
3 |
+
from collections import defaultdict, deque
|
4 |
+
from typing import Any, Deque, Dict, List, Set
|
5 |
+
|
6 |
+
|
7 |
+
class MetricsKeeper:
|
8 |
+
def __init__(self, window=60 * 5):
|
9 |
+
"""
|
10 |
+
Initializes the MetricsKeeper.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
window (int): Time window in seconds for recent metrics. Defaults to 5 minutes.
|
14 |
+
"""
|
15 |
+
self.window = window # Time window in seconds
|
16 |
+
self.start_time = time.time() # Timestamp when MetricsKeeper was created
|
17 |
+
self.total_metrics = defaultdict(int) # Cumulative metrics since start
|
18 |
+
self.window_metrics: Deque[Any] = deque() # Deque to store (timestamp, metrics_dict)
|
19 |
+
self.window_sum = defaultdict(int) # Sum of metrics within the window
|
20 |
+
|
21 |
+
def add_metrics(self, **kwargs):
|
22 |
+
"""
|
23 |
+
Adds metrics to the keeper.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
**kwargs: Arbitrary keyword arguments representing metric names and their values.
|
27 |
+
"""
|
28 |
+
current_time = time.time()
|
29 |
+
# Update cumulative metrics
|
30 |
+
for key, value in kwargs.items():
|
31 |
+
self.total_metrics[key] += value
|
32 |
+
|
33 |
+
# Append current metrics with timestamp to the deque
|
34 |
+
self.window_metrics.append((current_time, kwargs))
|
35 |
+
|
36 |
+
# Update window sums
|
37 |
+
for key, value in kwargs.items():
|
38 |
+
self.window_sum[key] += value
|
39 |
+
|
40 |
+
# Remove metrics that are outside the time window
|
41 |
+
while self.window_metrics and self.window_metrics[0][0] < current_time - self.window:
|
42 |
+
old_time, old_metrics = self.window_metrics.popleft()
|
43 |
+
for key, value in old_metrics.items():
|
44 |
+
self.window_sum[key] -= value
|
45 |
+
if self.window_sum[key] <= 0:
|
46 |
+
del self.window_sum[key] # Clean up to prevent negative counts
|
47 |
+
|
48 |
+
def __str__(self):
|
49 |
+
"""
|
50 |
+
Returns a formatted string of metrics showing tokens/sec since start and within the window.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
str: Formatted metrics string as a table.
|
54 |
+
"""
|
55 |
+
current_time = time.time()
|
56 |
+
elapsed_time = current_time - self.start_time
|
57 |
+
window_time = min(self.window, elapsed_time) if elapsed_time > 0 else 1 # Prevent division by zero
|
58 |
+
|
59 |
+
# Header
|
60 |
+
header = f"{'Metric Name':<30} {'Lifetime (tokens/sec)':>25} {'Recently (tokens/sec)':>25}"
|
61 |
+
separator = "-" * len(header)
|
62 |
+
lines = [header, separator]
|
63 |
+
|
64 |
+
# Sort metrics alphabetically for consistency
|
65 |
+
for key in sorted(self.total_metrics.keys()):
|
66 |
+
total = self.total_metrics[key]
|
67 |
+
window = self.window_sum.get(key, 0)
|
68 |
+
total_rate = total / elapsed_time if elapsed_time > 0 else 0
|
69 |
+
window_rate = window / window_time if window_time > 0 else 0
|
70 |
+
line = f"{key:<20} {total_rate:>25.2f} {window_rate:>25.2f}"
|
71 |
+
lines.append(line)
|
72 |
+
|
73 |
+
return "\n".join(lines)
|
74 |
+
|
75 |
+
|
76 |
+
class WorkerTracker:
|
77 |
+
def __init__(self):
|
78 |
+
"""
|
79 |
+
Initializes the WorkerTracker with a default dictionary.
|
80 |
+
Each worker ID maps to another dictionary that holds counts for each state.
|
81 |
+
"""
|
82 |
+
# Mapping from worker_id to a dictionary of state counts
|
83 |
+
self.worker_status: Dict[int, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
84 |
+
self.lock = asyncio.Lock()
|
85 |
+
|
86 |
+
async def clear_work(self, worker_id: int):
|
87 |
+
async with self.lock:
|
88 |
+
self.worker_status[worker_id].clear()
|
89 |
+
|
90 |
+
async def track_work(self, worker_id: int, work_item_id: str, state: str):
|
91 |
+
"""
|
92 |
+
Update the state count for a specific worker.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
worker_id (int): The ID of the worker.
|
96 |
+
work_item_id (str): The unique identifier of the work item (unused in this implementation).
|
97 |
+
state (str): The state to increment for the work item.
|
98 |
+
"""
|
99 |
+
async with self.lock:
|
100 |
+
self.worker_status[worker_id][state] += 1
|
101 |
+
|
102 |
+
async def get_status_table(self) -> str:
|
103 |
+
"""
|
104 |
+
Generate a formatted table of the current status of all workers.
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
str: A string representation of the workers' statuses.
|
108 |
+
"""
|
109 |
+
async with self.lock:
|
110 |
+
# Determine all unique states across all workers
|
111 |
+
all_states: Set[str] = set()
|
112 |
+
for states in self.worker_status.values():
|
113 |
+
all_states.update(states.keys())
|
114 |
+
sorted_states: List[str] = sorted(all_states)
|
115 |
+
|
116 |
+
headers = ["Worker ID"] + sorted_states # type: ignore
|
117 |
+
rows = []
|
118 |
+
for worker_id, states in sorted(self.worker_status.items()):
|
119 |
+
row = [str(worker_id)]
|
120 |
+
for state in sorted_states:
|
121 |
+
count = states.get(state, 0)
|
122 |
+
row.append(str(count))
|
123 |
+
rows.append(row)
|
124 |
+
|
125 |
+
# Calculate column widths
|
126 |
+
col_widths = [len(header) for header in headers]
|
127 |
+
for row in rows:
|
128 |
+
for idx, cell in enumerate(row):
|
129 |
+
col_widths[idx] = max(col_widths[idx], len(cell))
|
130 |
+
|
131 |
+
# Create the table header
|
132 |
+
header_line = " | ".join(header.ljust(col_widths[idx]) for idx, header in enumerate(headers))
|
133 |
+
separator = "-+-".join("-" * col_widths[idx] for idx in range(len(headers)))
|
134 |
+
|
135 |
+
# Create the table rows
|
136 |
+
row_lines = [" | ".join(cell.ljust(col_widths[idx]) for idx, cell in enumerate(row)) for row in rows]
|
137 |
+
|
138 |
+
# Combine all parts
|
139 |
+
table = "\n".join([header_line, separator] + row_lines)
|
140 |
+
return table
|
141 |
+
|
142 |
+
def __str__(self):
|
143 |
+
"""
|
144 |
+
String representation is not directly supported.
|
145 |
+
Use 'await get_status_table()' to retrieve the status table.
|
146 |
+
"""
|
147 |
+
raise NotImplementedError("Use 'await get_status_table()' to get the status table.")
|
ocrflux/pipeline.py
ADDED
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import asyncio
|
3 |
+
import atexit
|
4 |
+
import base64
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import shutil
|
8 |
+
import os
|
9 |
+
import copy
|
10 |
+
import random
|
11 |
+
import re
|
12 |
+
import sys
|
13 |
+
import time
|
14 |
+
from concurrent.futures.process import BrokenProcessPool
|
15 |
+
from io import BytesIO
|
16 |
+
from urllib.parse import urlparse
|
17 |
+
|
18 |
+
import httpx
|
19 |
+
from huggingface_hub import snapshot_download
|
20 |
+
from PIL import Image
|
21 |
+
from pypdf import PdfReader
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
from ocrflux.check import (
|
25 |
+
check_poppler_version,
|
26 |
+
check_vllm_version,
|
27 |
+
check_torch_gpu_available,
|
28 |
+
)
|
29 |
+
from ocrflux.image_utils import get_page_image, is_image
|
30 |
+
from ocrflux.table_format import trans_markdown_text
|
31 |
+
from ocrflux.metrics import MetricsKeeper, WorkerTracker
|
32 |
+
from ocrflux.prompts import PageResponse, build_page_to_markdown_prompt, build_element_merge_detect_prompt, build_html_table_merge_prompt
|
33 |
+
from ocrflux.work_queue import LocalWorkQueue, WorkQueue
|
34 |
+
|
35 |
+
# Initialize logger
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
logger.setLevel(logging.DEBUG)
|
38 |
+
logger.propagate = False
|
39 |
+
|
40 |
+
vllm_logger = logging.getLogger("vllm")
|
41 |
+
vllm_logger.propagate = False
|
42 |
+
|
43 |
+
file_handler = logging.FileHandler("OCRFlux-debug.log", mode="a")
|
44 |
+
file_handler.setLevel(logging.DEBUG)
|
45 |
+
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
|
46 |
+
|
47 |
+
console_handler = logging.StreamHandler()
|
48 |
+
console_handler.setLevel(logging.INFO)
|
49 |
+
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
|
50 |
+
|
51 |
+
# Add handlers to the logger
|
52 |
+
logger.addHandler(file_handler)
|
53 |
+
logger.addHandler(console_handler)
|
54 |
+
vllm_logger.addHandler(file_handler)
|
55 |
+
|
56 |
+
# Quiet logs from pypdf
|
57 |
+
logging.getLogger("pypdf").setLevel(logging.ERROR)
|
58 |
+
|
59 |
+
# Global variables for token statistics
|
60 |
+
metrics = MetricsKeeper(window=60 * 5)
|
61 |
+
tracker = WorkerTracker()
|
62 |
+
|
63 |
+
def build_page_to_markdown_query(args, pdf_path: str, page_number: int, target_longest_image_dim: int, image_rotation: int = 0) -> dict:
|
64 |
+
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
|
65 |
+
|
66 |
+
image = get_page_image(pdf_path, page_number, target_longest_image_dim=target_longest_image_dim, image_rotation=image_rotation)
|
67 |
+
buffered = BytesIO()
|
68 |
+
image.save(buffered, format="PNG")
|
69 |
+
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
70 |
+
|
71 |
+
return {
|
72 |
+
"model": args.model,
|
73 |
+
"messages": [
|
74 |
+
{
|
75 |
+
"role": "user",
|
76 |
+
"content": [
|
77 |
+
{"type": "text", "text": build_page_to_markdown_prompt()},
|
78 |
+
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
|
79 |
+
],
|
80 |
+
}
|
81 |
+
],
|
82 |
+
"temperature": 0.0,
|
83 |
+
}
|
84 |
+
|
85 |
+
def build_element_merge_detect_query(args,text_list_1,text_list_2) -> dict:
|
86 |
+
image = Image.new('RGB', (28, 28), color='black')
|
87 |
+
|
88 |
+
buffered = BytesIO()
|
89 |
+
image.save(buffered, format="PNG")
|
90 |
+
|
91 |
+
image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
92 |
+
|
93 |
+
return {
|
94 |
+
"model": args.model,
|
95 |
+
"messages": [
|
96 |
+
{
|
97 |
+
"role": "user",
|
98 |
+
"content": [
|
99 |
+
{"type": "text", "text": build_element_merge_detect_prompt(text_list_1,text_list_2)},
|
100 |
+
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
|
101 |
+
],
|
102 |
+
}
|
103 |
+
],
|
104 |
+
"temperature": 0.0,
|
105 |
+
}
|
106 |
+
|
107 |
+
def build_html_table_merge_query(args,text_1,text_2) -> dict:
|
108 |
+
image = Image.new('RGB', (28, 28), color='black')
|
109 |
+
|
110 |
+
buffered = BytesIO()
|
111 |
+
image.save(buffered, format="PNG")
|
112 |
+
|
113 |
+
image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
114 |
+
|
115 |
+
return {
|
116 |
+
"model": args.model,
|
117 |
+
"messages": [
|
118 |
+
{
|
119 |
+
"role": "user",
|
120 |
+
"content": [
|
121 |
+
{"type": "text", "text": build_html_table_merge_prompt(text_1,text_2)},
|
122 |
+
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
|
123 |
+
],
|
124 |
+
}
|
125 |
+
],
|
126 |
+
"temperature": 0.0,
|
127 |
+
}
|
128 |
+
|
129 |
+
# Manual simple implementation of HTTP Post
|
130 |
+
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
|
131 |
+
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
|
132 |
+
# that at the scale of 100M+ requests, that they deadlock in different strange ways
|
133 |
+
async def apost(url, json_data):
|
134 |
+
parsed_url = urlparse(url)
|
135 |
+
host = parsed_url.hostname
|
136 |
+
port = parsed_url.port or 80
|
137 |
+
path = parsed_url.path or "/"
|
138 |
+
|
139 |
+
writer = None
|
140 |
+
try:
|
141 |
+
reader, writer = await asyncio.open_connection(host, port)
|
142 |
+
|
143 |
+
json_payload = json.dumps(json_data)
|
144 |
+
request = (
|
145 |
+
f"POST {path} HTTP/1.1\r\n"
|
146 |
+
f"Host: {host}\r\n"
|
147 |
+
f"Content-Type: application/json\r\n"
|
148 |
+
f"Content-Length: {len(json_payload)}\r\n"
|
149 |
+
f"Connection: close\r\n\r\n"
|
150 |
+
f"{json_payload}"
|
151 |
+
)
|
152 |
+
writer.write(request.encode())
|
153 |
+
await writer.drain()
|
154 |
+
|
155 |
+
# Read status line
|
156 |
+
status_line = await reader.readline()
|
157 |
+
if not status_line:
|
158 |
+
raise ConnectionError("No response from server")
|
159 |
+
status_parts = status_line.decode().strip().split(" ", 2)
|
160 |
+
if len(status_parts) < 2:
|
161 |
+
raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
|
162 |
+
status_code = int(status_parts[1])
|
163 |
+
|
164 |
+
# Read headers
|
165 |
+
headers = {}
|
166 |
+
while True:
|
167 |
+
line = await reader.readline()
|
168 |
+
if line in (b"\r\n", b"\n", b""):
|
169 |
+
break
|
170 |
+
key, _, value = line.decode().partition(":")
|
171 |
+
headers[key.strip().lower()] = value.strip()
|
172 |
+
|
173 |
+
# Read response body
|
174 |
+
if "content-length" in headers:
|
175 |
+
body_length = int(headers["content-length"])
|
176 |
+
response_body = await reader.readexactly(body_length)
|
177 |
+
else:
|
178 |
+
raise ConnectionError("Anything other than fixed content length responses are not implemented yet")
|
179 |
+
|
180 |
+
return status_code, response_body
|
181 |
+
except Exception as e:
|
182 |
+
# Pass through errors
|
183 |
+
raise e
|
184 |
+
finally:
|
185 |
+
# But just make sure to close the socket on your way out
|
186 |
+
if writer is not None:
|
187 |
+
try:
|
188 |
+
writer.close()
|
189 |
+
await writer.wait_closed()
|
190 |
+
except:
|
191 |
+
pass
|
192 |
+
|
193 |
+
async def process_task(args, worker_id, task_name, task_args):
|
194 |
+
COMPLETION_URL = f"http://localhost:{args.port}/v1/chat/completions"
|
195 |
+
MAX_RETRIES = args.max_page_retries
|
196 |
+
TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
|
197 |
+
exponential_backoffs = 0
|
198 |
+
local_image_rotation = 0
|
199 |
+
attempt = 0
|
200 |
+
await tracker.track_work(worker_id, f"{worker_id}", "started")
|
201 |
+
while attempt < MAX_RETRIES:
|
202 |
+
if task_name == 'page_to_markdown':
|
203 |
+
pdf_path,page_number = task_args
|
204 |
+
query = build_page_to_markdown_query(args, pdf_path, page_number, args.target_longest_image_dim, image_rotation=local_image_rotation)
|
205 |
+
elif task_name == 'element_merge_detect':
|
206 |
+
text_list_1,text_list_2 = task_args
|
207 |
+
query = build_element_merge_detect_query(args, text_list_1, text_list_2)
|
208 |
+
elif task_name == 'html_table_merge':
|
209 |
+
table_1,table_2 = task_args
|
210 |
+
query = build_html_table_merge_query(args, table_1, table_2)
|
211 |
+
query["temperature"] = TEMPERATURE_BY_ATTEMPT[
|
212 |
+
min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
|
213 |
+
] # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
|
214 |
+
|
215 |
+
try:
|
216 |
+
status_code, response_body = await apost(COMPLETION_URL, json_data=query)
|
217 |
+
|
218 |
+
if status_code == 400:
|
219 |
+
raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
|
220 |
+
elif status_code == 500:
|
221 |
+
raise ValueError(f"Got InternalServerError from server: {response_body}, skipping this response")
|
222 |
+
elif status_code != 200:
|
223 |
+
raise ValueError(f"Error http status {status_code}")
|
224 |
+
|
225 |
+
base_response_data = json.loads(response_body)
|
226 |
+
|
227 |
+
metrics.add_metrics(
|
228 |
+
vllm_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
|
229 |
+
vllm_output_tokens=base_response_data["usage"].get("completion_tokens", 0),
|
230 |
+
)
|
231 |
+
|
232 |
+
response_content = base_response_data["choices"][0]["message"]["content"]
|
233 |
+
if task_name == 'page_to_markdown':
|
234 |
+
model_response_json = json.loads(response_content)
|
235 |
+
page_response = PageResponse(**model_response_json)
|
236 |
+
if not page_response.is_rotation_valid and attempt < MAX_RETRIES - 1:
|
237 |
+
local_image_rotation = page_response.rotation_correction
|
238 |
+
raise ValueError(f"invalid_page rotation")
|
239 |
+
try:
|
240 |
+
return_data = trans_markdown_text(page_response.natural_text,"matrix2html")
|
241 |
+
except:
|
242 |
+
if attempt < MAX_RETRIES - 1:
|
243 |
+
raise
|
244 |
+
else:
|
245 |
+
return_data = page_response.natural_text.replace("<t>","").replace("<l>","").replace("<lt>","")
|
246 |
+
|
247 |
+
elif task_name == 'element_merge_detect':
|
248 |
+
pattern = r"\((\d+), (\d+)\)"
|
249 |
+
matches = re.findall(pattern, response_content)
|
250 |
+
return_data = [(int(x), int(y)) for x, y in matches]
|
251 |
+
elif task_name == 'html_table_merge':
|
252 |
+
if not (response_content.startswith("<table>") and response_content.endswith("</table>")):
|
253 |
+
raise ValueError("Response is not a table")
|
254 |
+
return_data = response_content
|
255 |
+
else:
|
256 |
+
raise ValueError(f"Unknown task_name {task_name}")
|
257 |
+
|
258 |
+
await tracker.track_work(worker_id, f"{worker_id}", "finished")
|
259 |
+
return return_data
|
260 |
+
|
261 |
+
except (ConnectionError, OSError, asyncio.TimeoutError) as e:
|
262 |
+
logger.warning(f"Client error on attempt {attempt} for {worker_id}: {type(e)} {e}")
|
263 |
+
|
264 |
+
# Now we want to do exponential backoff, and not count this as an actual page retry
|
265 |
+
# Page retrys are supposed to be for fixing bad results from the model, but actual requests to vllm
|
266 |
+
# are supposed to work. Probably this means that the server is just restarting
|
267 |
+
sleep_delay = 10 * (2**exponential_backoffs)
|
268 |
+
exponential_backoffs += 1
|
269 |
+
logger.info(f"Sleeping for {sleep_delay} seconds on {worker_id} to allow server restart")
|
270 |
+
await asyncio.sleep(sleep_delay)
|
271 |
+
except asyncio.CancelledError:
|
272 |
+
logger.info(f"Process {worker_id} cancelled")
|
273 |
+
await tracker.track_work(worker_id, f"{worker_id}", "cancelled")
|
274 |
+
raise
|
275 |
+
except json.JSONDecodeError as e:
|
276 |
+
logger.warning(f"JSON decode error on attempt {attempt} for {worker_id}: {e}")
|
277 |
+
attempt += 1
|
278 |
+
except ValueError as e:
|
279 |
+
logger.warning(f"ValueError on attempt {attempt} for {worker_id}: {type(e)} - {e}")
|
280 |
+
attempt += 1
|
281 |
+
except Exception as e:
|
282 |
+
logger.exception(f"Unexpected error on attempt {attempt} for {worker_id}: {type(e)} - {e}")
|
283 |
+
attempt += 1
|
284 |
+
|
285 |
+
logger.error(f"Failed to process {worker_id} after {MAX_RETRIES} attempts.")
|
286 |
+
await tracker.track_work(worker_id, f"{worker_id}", "errored")
|
287 |
+
|
288 |
+
return None
|
289 |
+
|
290 |
+
def postprocess_markdown_text(args, response_text, pdf_path, page_number):
|
291 |
+
text_list = response_text.split("\n\n")
|
292 |
+
new_text_list = []
|
293 |
+
for text in text_list:
|
294 |
+
if text.startswith("<Image>") and text.endswith("</Image>"):
|
295 |
+
pass
|
296 |
+
else:
|
297 |
+
new_text_list.append(text)
|
298 |
+
return "\n\n".join(new_text_list)
|
299 |
+
|
300 |
+
def bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result):
|
301 |
+
page_to_markdown_keys = list(page_to_markdown_result.keys())
|
302 |
+
element_merge_detect_keys = list(element_merge_detect_result.keys())
|
303 |
+
html_table_merge_keys = list(html_table_merge_result.keys())
|
304 |
+
|
305 |
+
for page_1,page_2,elem_idx_1,elem_idx_2 in sorted(html_table_merge_keys,key=lambda x: -x[0]):
|
306 |
+
page_to_markdown_result[page_1][elem_idx_1] = html_table_merge_result[(page_1,page_2,elem_idx_1,elem_idx_2)]
|
307 |
+
page_to_markdown_result[page_2][elem_idx_2] = ''
|
308 |
+
|
309 |
+
for page_1,page_2 in sorted(element_merge_detect_keys,key=lambda x: -x[0]):
|
310 |
+
for elem_idx_1,elem_idx_2 in element_merge_detect_result[(page_1,page_2)]:
|
311 |
+
if len(page_to_markdown_result[page_1][elem_idx_1]) == 0 or page_to_markdown_result[page_1][elem_idx_1][-1] == '-' or ('\u4e00' <= page_to_markdown_result[page_1][elem_idx_1][-1] <= '\u9fff'):
|
312 |
+
page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + '' + page_to_markdown_result[page_2][elem_idx_2]
|
313 |
+
else:
|
314 |
+
page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + ' ' + page_to_markdown_result[page_2][elem_idx_2]
|
315 |
+
page_to_markdown_result[page_2][elem_idx_2] = ''
|
316 |
+
|
317 |
+
document_text_list = []
|
318 |
+
for page in page_to_markdown_keys:
|
319 |
+
page_text_list = [s for s in page_to_markdown_result[page] if s]
|
320 |
+
document_text_list += page_text_list
|
321 |
+
return "\n\n".join(document_text_list)
|
322 |
+
|
323 |
+
async def process_pdf(args, worker_id: int, pdf_path: str):
|
324 |
+
logger.info(f"Start process_pdf for {pdf_path}")
|
325 |
+
if pdf_path.lower().endswith(".pdf"):
|
326 |
+
try:
|
327 |
+
reader = PdfReader(pdf_path)
|
328 |
+
num_pages = reader.get_num_pages()
|
329 |
+
except:
|
330 |
+
logger.exception(f"Could not count number of pages for {pdf_path}, aborting document")
|
331 |
+
return None
|
332 |
+
else:
|
333 |
+
num_pages = 1
|
334 |
+
|
335 |
+
logger.info(f"Got {num_pages} pages to do for {pdf_path} in worker {worker_id}")
|
336 |
+
|
337 |
+
try:
|
338 |
+
tasks = []
|
339 |
+
results = []
|
340 |
+
async with asyncio.TaskGroup() as tg:
|
341 |
+
for page_num in range(1, num_pages + 1):
|
342 |
+
task = tg.create_task(process_task(args, worker_id, task_name='page_to_markdown', task_args=(pdf_path,page_num)))
|
343 |
+
tasks.append(task)
|
344 |
+
|
345 |
+
results = [task.result() for task in tasks]
|
346 |
+
|
347 |
+
fallback_pages = []
|
348 |
+
page_to_markdown_result = {}
|
349 |
+
page_pairs = []
|
350 |
+
for i,result in enumerate(results):
|
351 |
+
if result != None:
|
352 |
+
page_number = i+1
|
353 |
+
page_to_markdown_result[i+1] = postprocess_markdown_text(args,result,pdf_path,page_number).split("\n\n")
|
354 |
+
if page_number-1 in page_to_markdown_result.keys():
|
355 |
+
page_pairs.append((page_number-1,page_number))
|
356 |
+
else:
|
357 |
+
fallback_pages.append(i)
|
358 |
+
|
359 |
+
num_fallback_pages = len(fallback_pages)
|
360 |
+
|
361 |
+
if num_fallback_pages / num_pages > args.max_page_error_rate:
|
362 |
+
logger.error(
|
363 |
+
f"Document {pdf_path} has {num_fallback_pages} fallback pages out of {num_pages} exceeding max_page_error_rate of {args.max_page_error_rate}, discarding document."
|
364 |
+
)
|
365 |
+
return None
|
366 |
+
elif num_fallback_pages > 0:
|
367 |
+
logger.warning(
|
368 |
+
f"Document {pdf_path} processed with {num_fallback_pages} fallback pages out of {num_pages}."
|
369 |
+
)
|
370 |
+
|
371 |
+
if args.skip_cross_page_merge:
|
372 |
+
page_texts = {}
|
373 |
+
document_text_list = []
|
374 |
+
sorted_page_keys = sorted(list(page_to_markdown_result.keys()))
|
375 |
+
for page_number in sorted_page_keys:
|
376 |
+
page_texts[str(page_number-1)] = "\n\n".join(page_to_markdown_result[page_number])
|
377 |
+
document_text_list.append(page_texts[str(page_number-1)])
|
378 |
+
document_text = "\n\n".join(document_text_list)
|
379 |
+
return {
|
380 |
+
"orig_path": pdf_path,
|
381 |
+
"num_pages": num_pages,
|
382 |
+
"document_text": document_text,
|
383 |
+
"page_texts": page_texts,
|
384 |
+
"fallback_pages": fallback_pages,
|
385 |
+
}
|
386 |
+
|
387 |
+
tasks = []
|
388 |
+
results = []
|
389 |
+
async with asyncio.TaskGroup() as tg:
|
390 |
+
for page_1,page_2 in page_pairs:
|
391 |
+
task = tg.create_task(process_task(args, worker_id, task_name='element_merge_detect', task_args=(page_to_markdown_result[page_1], page_to_markdown_result[page_2])))
|
392 |
+
tasks.append(task)
|
393 |
+
results = [task.result() for task in tasks]
|
394 |
+
|
395 |
+
element_merge_detect_result = {}
|
396 |
+
table_pairs = []
|
397 |
+
for page_pair,result in zip(page_pairs,results):
|
398 |
+
if result != None:
|
399 |
+
page_1,page_2 = page_pair
|
400 |
+
element_merge_detect_result[(page_1,page_2)] = result
|
401 |
+
for elem_idx_1,elem_idx_2 in result:
|
402 |
+
text_1 = page_to_markdown_result[page_1][elem_idx_1]
|
403 |
+
text_2 = page_to_markdown_result[page_2][elem_idx_2]
|
404 |
+
if text_1.startswith("<table>") and text_1.endswith("</table>") and text_2.startswith("<table>") and text_2.endswith("</table>"):
|
405 |
+
table_pairs.append((page_1,page_2,elem_idx_1,elem_idx_2))
|
406 |
+
|
407 |
+
tmp_page_to_markdown_result = copy.deepcopy(page_to_markdown_result)
|
408 |
+
table_pairs = sorted(table_pairs,key=lambda x: -x[0])
|
409 |
+
html_table_merge_result = {}
|
410 |
+
i = 0
|
411 |
+
while i < len(table_pairs):
|
412 |
+
async with asyncio.TaskGroup() as tg:
|
413 |
+
tasks = []
|
414 |
+
ids_1 = []
|
415 |
+
ids_2 = []
|
416 |
+
page_1,page_2,elem_idx_1,elem_idx_2 = table_pairs[i]
|
417 |
+
task = tg.create_task(process_task(args, worker_id, task_name='html_table_merge', task_args=(tmp_page_to_markdown_result[page_1][elem_idx_1], tmp_page_to_markdown_result[page_2][elem_idx_2])))
|
418 |
+
tasks.append(task)
|
419 |
+
ids_1.append((page_1,elem_idx_1))
|
420 |
+
ids_2.append((page_2,elem_idx_2))
|
421 |
+
j = i + 1
|
422 |
+
while j < len(table_pairs):
|
423 |
+
page_1,page_2,elem_idx_1,elem_idx_2 = table_pairs[j]
|
424 |
+
if (page_2, elem_idx_2) not in ids_1:
|
425 |
+
task = tg.create_task(process_task(args, worker_id, task_name='html_table_merge', task_args=(tmp_page_to_markdown_result[page_1][elem_idx_1], tmp_page_to_markdown_result[page_2][elem_idx_2])))
|
426 |
+
tasks.append(task)
|
427 |
+
ids_1.append((page_1,elem_idx_1))
|
428 |
+
ids_2.append((page_2,elem_idx_2))
|
429 |
+
j = j + 1
|
430 |
+
else:
|
431 |
+
break
|
432 |
+
|
433 |
+
results = [task.result() for task in tasks]
|
434 |
+
|
435 |
+
for k,result in enumerate(results):
|
436 |
+
page_1,elem_idx_1 = ids_1[k]
|
437 |
+
page_2,elem_idx_2 = ids_2[k]
|
438 |
+
if result != None:
|
439 |
+
html_table_merge_result[(page_1,page_2,elem_idx_1,elem_idx_2)] = result
|
440 |
+
tmp_page_to_markdown_result[page_1][elem_idx_1] = html_table_merge_result[(page_1,page_2,elem_idx_1,elem_idx_2)]
|
441 |
+
i = j
|
442 |
+
|
443 |
+
page_texts = {}
|
444 |
+
for page_number in page_to_markdown_result.keys():
|
445 |
+
page_texts[str(page_number-1)] = "\n\n".join(page_to_markdown_result[page_number])
|
446 |
+
|
447 |
+
document_text = bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result)
|
448 |
+
|
449 |
+
return {
|
450 |
+
"orig_path": pdf_path,
|
451 |
+
"num_pages": num_pages,
|
452 |
+
"document_text": document_text,
|
453 |
+
"page_texts": page_texts,
|
454 |
+
"fallback_pages": fallback_pages,
|
455 |
+
}
|
456 |
+
except Exception as e:
|
457 |
+
# Check for ExceptionGroup with BrokenProcessPool
|
458 |
+
if isinstance(e, ExceptionGroup):
|
459 |
+
broken_pool, other = e.split(BrokenProcessPool)
|
460 |
+
if broken_pool is not None: # Found at least one BrokenProcessPool
|
461 |
+
logger.critical("Encountered BrokenProcessPool, exiting process.")
|
462 |
+
sys.exit(1)
|
463 |
+
|
464 |
+
logger.exception(f"Exception in process_pdf for {pdf_path}: {e}")
|
465 |
+
return None
|
466 |
+
|
467 |
+
async def process_json(args, worker_id: int, json_path: str):
|
468 |
+
try:
|
469 |
+
json_data = json.load(open(json_path,'r'))
|
470 |
+
except:
|
471 |
+
logger.exception(f"Could not load {json_path}, aborting document")
|
472 |
+
try:
|
473 |
+
if args.task == 'merge_pages':
|
474 |
+
page_1 = json_data['page_1'].split("\n\n")
|
475 |
+
page_2 = json_data['page_2'].split("\n\n")
|
476 |
+
async with asyncio.TaskGroup() as tg:
|
477 |
+
task = tg.create_task(process_task(args, worker_id, task_name='element_merge_detect', task_args=(page_1, page_2)))
|
478 |
+
result = task.result()
|
479 |
+
return {
|
480 |
+
"orig_path": json_path,
|
481 |
+
"merge_pairs": result
|
482 |
+
}
|
483 |
+
elif args.task == 'merge_tables':
|
484 |
+
table_1 = json_data['table_1']
|
485 |
+
table_2 = json_data['table_2']
|
486 |
+
async with asyncio.TaskGroup() as tg:
|
487 |
+
task = tg.create_task(process_task(args, worker_id, task_name='html_table_merge', task_args=(table_1, table_2)))
|
488 |
+
result = task.result()
|
489 |
+
return {
|
490 |
+
"orig_path": json_path,
|
491 |
+
"merged_tables": result
|
492 |
+
}
|
493 |
+
else:
|
494 |
+
raise ValueError(f"Unknown task {args.task}")
|
495 |
+
|
496 |
+
except Exception as e:
|
497 |
+
# Check for ExceptionGroup with BrokenProcessPool
|
498 |
+
if isinstance(e, ExceptionGroup):
|
499 |
+
broken_pool, other = e.split(BrokenProcessPool)
|
500 |
+
if broken_pool is not None: # Found at least one BrokenProcessPool
|
501 |
+
logger.critical("Encountered BrokenProcessPool, exiting process.")
|
502 |
+
sys.exit(1)
|
503 |
+
|
504 |
+
logger.exception(f"Exception in process_json for {json_path}: {e}")
|
505 |
+
return None
|
506 |
+
|
507 |
+
async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
|
508 |
+
while True:
|
509 |
+
# Wait until allowed to proceed
|
510 |
+
await semaphore.acquire()
|
511 |
+
|
512 |
+
work_item = await work_queue.get_work()
|
513 |
+
|
514 |
+
if work_item is None:
|
515 |
+
logger.info(f"Worker {worker_id} exiting due to empty queue")
|
516 |
+
semaphore.release()
|
517 |
+
break
|
518 |
+
|
519 |
+
logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
|
520 |
+
await tracker.clear_work(worker_id)
|
521 |
+
|
522 |
+
try:
|
523 |
+
async with asyncio.TaskGroup() as tg:
|
524 |
+
if args.task == 'pdf2markdown':
|
525 |
+
tasks = [tg.create_task(process_pdf(args, worker_id, pdf_path)) for pdf_path in work_item.work_paths]
|
526 |
+
elif args.task == 'merge_pages' or args.task == 'merge_tables':
|
527 |
+
tasks = [tg.create_task(process_json(args, worker_id, json_path)) for json_path in work_item.work_paths]
|
528 |
+
else:
|
529 |
+
raise ValueError(f"Unknown task {args.task}")
|
530 |
+
|
531 |
+
logger.info(f"Created all tasks for {work_item.hash}")
|
532 |
+
|
533 |
+
logger.info(f"Finished TaskGroup for worker on {work_item.hash}")
|
534 |
+
|
535 |
+
results = []
|
536 |
+
for task in tasks:
|
537 |
+
try:
|
538 |
+
result = task.result()
|
539 |
+
except:
|
540 |
+
pass
|
541 |
+
|
542 |
+
if result is not None:
|
543 |
+
results.append(result)
|
544 |
+
|
545 |
+
logger.info(f"Got {len(results)} docs for {work_item.hash}")
|
546 |
+
|
547 |
+
output_final_path = os.path.join(args.workspace, "results", f"output_{work_item.hash}.jsonl")
|
548 |
+
with open(output_final_path, "w") as f:
|
549 |
+
for result in results:
|
550 |
+
f.write(json.dumps(result))
|
551 |
+
f.write("\n")
|
552 |
+
|
553 |
+
await work_queue.mark_done(work_item)
|
554 |
+
except Exception as e:
|
555 |
+
logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
|
556 |
+
finally:
|
557 |
+
semaphore.release()
|
558 |
+
|
559 |
+
async def vllm_server_task(args, semaphore):
|
560 |
+
model_name_or_path = args.model
|
561 |
+
|
562 |
+
cmd = [
|
563 |
+
"vllm",
|
564 |
+
"serve",
|
565 |
+
model_name_or_path,
|
566 |
+
"--port",
|
567 |
+
str(args.port),
|
568 |
+
"--max-model-len",
|
569 |
+
str(args.model_max_context),
|
570 |
+
"--gpu_memory_utilization",
|
571 |
+
str(0.8)
|
572 |
+
]
|
573 |
+
|
574 |
+
proc = await asyncio.create_subprocess_exec(
|
575 |
+
*cmd,
|
576 |
+
stdout=asyncio.subprocess.PIPE,
|
577 |
+
stderr=asyncio.subprocess.PIPE,
|
578 |
+
)
|
579 |
+
|
580 |
+
# Ensure the subprocess is terminated on exit
|
581 |
+
def _kill_proc():
|
582 |
+
proc.terminate()
|
583 |
+
|
584 |
+
atexit.register(_kill_proc)
|
585 |
+
|
586 |
+
# Shared variables between tasks
|
587 |
+
last_running_req, last_queue_req = 0, 0
|
588 |
+
server_printed_ready_message = False
|
589 |
+
last_semaphore_release = time.time()
|
590 |
+
|
591 |
+
async def process_line(line):
|
592 |
+
nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
|
593 |
+
vllm_logger.info(line)
|
594 |
+
|
595 |
+
# if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
|
596 |
+
# can see any warnings/errors more easily
|
597 |
+
if not server_printed_ready_message:
|
598 |
+
logger.info(line)
|
599 |
+
|
600 |
+
if "Detected errors during sampling" in line:
|
601 |
+
logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
|
602 |
+
sys.exit(1)
|
603 |
+
|
604 |
+
# TODO, need to trace down this issue in vllm itself, but it will otherwise cause the server to lock up
|
605 |
+
if "IndexError: list index out of range" in line:
|
606 |
+
logger.error("IndexError in model, restarting server")
|
607 |
+
proc.terminate()
|
608 |
+
|
609 |
+
if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
|
610 |
+
server_printed_ready_message = True
|
611 |
+
last_semaphore_release = time.time()
|
612 |
+
|
613 |
+
match = re.search(r"Running: (\d+)", line)
|
614 |
+
if match:
|
615 |
+
last_running_req = int(match.group(1))
|
616 |
+
|
617 |
+
match = re.search(r"(?:Waiting|Pending):\s*(\d+)", line)
|
618 |
+
if match:
|
619 |
+
last_queue_req = int(match.group(1))
|
620 |
+
logger.info(f"vllm running req: {last_running_req} queue req: {last_queue_req}")
|
621 |
+
|
622 |
+
async def read_stream(stream):
|
623 |
+
while True:
|
624 |
+
line = await stream.readline()
|
625 |
+
if not line:
|
626 |
+
break
|
627 |
+
try:
|
628 |
+
line = line.decode("utf-8").rstrip()
|
629 |
+
await process_line(line)
|
630 |
+
except Exception as ex:
|
631 |
+
logger.warning(f"Got {ex} when reading log line from inference server, skipping")
|
632 |
+
|
633 |
+
async def timeout_task():
|
634 |
+
nonlocal last_running_req, last_queue_req, last_semaphore_release
|
635 |
+
try:
|
636 |
+
while True:
|
637 |
+
await asyncio.sleep(1)
|
638 |
+
if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
|
639 |
+
semaphore.release()
|
640 |
+
last_semaphore_release = time.time()
|
641 |
+
logger.info("Semaphore released, allowing a worker to proceed.")
|
642 |
+
except asyncio.CancelledError:
|
643 |
+
pass # Clean up if the task is cancelled
|
644 |
+
|
645 |
+
# Start tasks to read stdout, stderr, and handle timeout logic
|
646 |
+
stdout_task = asyncio.create_task(read_stream(proc.stdout))
|
647 |
+
stderr_task = asyncio.create_task(read_stream(proc.stderr))
|
648 |
+
timeout_task = asyncio.create_task(timeout_task())
|
649 |
+
|
650 |
+
try:
|
651 |
+
await proc.wait()
|
652 |
+
except asyncio.CancelledError:
|
653 |
+
logger.info("Got cancellation request for VLLM server")
|
654 |
+
proc.terminate()
|
655 |
+
raise
|
656 |
+
|
657 |
+
timeout_task.cancel()
|
658 |
+
await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
|
659 |
+
|
660 |
+
async def vllm_server_host(args, semaphore):
|
661 |
+
MAX_RETRIES = 5
|
662 |
+
retry = 0
|
663 |
+
|
664 |
+
while retry < MAX_RETRIES:
|
665 |
+
await vllm_server_task(args, semaphore)
|
666 |
+
logger.warning("VLLM server task ended")
|
667 |
+
retry += 1
|
668 |
+
|
669 |
+
if retry >= MAX_RETRIES:
|
670 |
+
logger.error(f"Ended up starting the vllm server more than {retry} times, cancelling pipeline")
|
671 |
+
logger.error("")
|
672 |
+
logger.error("Please make sure vllm is installed according to the latest instructions here: https://docs.vllm.ai/start/install.html")
|
673 |
+
sys.exit(1)
|
674 |
+
|
675 |
+
async def vllm_server_ready(args):
|
676 |
+
max_attempts = 300
|
677 |
+
delay_sec = 1
|
678 |
+
url = f"http://localhost:{args.port}/v1/models"
|
679 |
+
|
680 |
+
for attempt in range(1, max_attempts + 1):
|
681 |
+
try:
|
682 |
+
async with httpx.AsyncClient() as session:
|
683 |
+
response = await session.get(url)
|
684 |
+
|
685 |
+
if response.status_code == 200:
|
686 |
+
logger.info("vllm server is ready.")
|
687 |
+
return
|
688 |
+
else:
|
689 |
+
logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
|
690 |
+
except Exception:
|
691 |
+
logger.warning(f"Attempt {attempt}: Please wait for vllm server to become ready...")
|
692 |
+
|
693 |
+
await asyncio.sleep(delay_sec)
|
694 |
+
|
695 |
+
raise Exception("vllm server did not become ready after waiting.")
|
696 |
+
|
697 |
+
async def download_model(model_name_or_path: str):
|
698 |
+
if os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
|
699 |
+
logger.info(f"Using local model path at '{model_name_or_path}'")
|
700 |
+
else:
|
701 |
+
logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
|
702 |
+
snapshot_download(repo_id=model_name_or_path)
|
703 |
+
|
704 |
+
async def metrics_reporter(work_queue):
|
705 |
+
while True:
|
706 |
+
# Leading newlines preserve table formatting in logs
|
707 |
+
logger.info(f"Queue remaining: {work_queue.size}")
|
708 |
+
logger.info("\n" + str(metrics))
|
709 |
+
logger.info("\n" + str(await tracker.get_status_table()))
|
710 |
+
await asyncio.sleep(10)
|
711 |
+
|
712 |
+
async def main():
|
713 |
+
parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline")
|
714 |
+
parser.add_argument(
|
715 |
+
"workspace",
|
716 |
+
help="The filesystem path where work will be stored, can be a local folder",
|
717 |
+
)
|
718 |
+
|
719 |
+
parser.add_argument("--task", type=str, choices=['pdf2markdown','merge_pages','merge_tables'], default='pdf2markdown', help="task names, could be 'pdf2markdown', 'merge_pages' or 'merge_tables'")
|
720 |
+
|
721 |
+
parser.add_argument(
|
722 |
+
"--data",
|
723 |
+
nargs="*",
|
724 |
+
help="List of paths to files to process",
|
725 |
+
default=None,
|
726 |
+
)
|
727 |
+
|
728 |
+
parser.add_argument("--pages_per_group", type=int, default=500, help="Aiming for this many pdf pages per work item group")
|
729 |
+
parser.add_argument("--max_page_retries", type=int, default=8, help="Max number of times we will retry rendering a page")
|
730 |
+
parser.add_argument("--max_page_error_rate", type=float, default=0.004, help="Rate of allowable failed pages in a document, 1/250 by default")
|
731 |
+
parser.add_argument("--workers", type=int, default=8, help="Number of workers to run at a time")
|
732 |
+
|
733 |
+
# Model parameters
|
734 |
+
parser.add_argument(
|
735 |
+
"--model",
|
736 |
+
help="The path to the model",
|
737 |
+
default="ChatDOC/OCRFlux-3B",
|
738 |
+
)
|
739 |
+
parser.add_argument("--model_max_context", type=int, default=16384, help="Maximum context length that the model was fine tuned under")
|
740 |
+
parser.add_argument("--model_chat_template", type=str, default="qwen2-vl", help="Chat template to pass to vllm server")
|
741 |
+
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1024)
|
742 |
+
|
743 |
+
parser.add_argument("--skip_cross_page_merge", action="store_true", help="Whether to skip cross-page merging")
|
744 |
+
|
745 |
+
parser.add_argument("--port", type=int, default=40078, help="Port to use for the VLLM server")
|
746 |
+
args = parser.parse_args()
|
747 |
+
|
748 |
+
if os.path.exists(args.workspace):
|
749 |
+
shutil.rmtree(args.workspace)
|
750 |
+
|
751 |
+
# We need poppler to load the initial pdfs, even if we are not processing them here
|
752 |
+
check_poppler_version()
|
753 |
+
|
754 |
+
work_queue = LocalWorkQueue(args.workspace)
|
755 |
+
|
756 |
+
if args.task == 'pdf2markdown':
|
757 |
+
pdf_work_paths = set()
|
758 |
+
|
759 |
+
for pdf_path in args.data:
|
760 |
+
if os.path.exists(pdf_path):
|
761 |
+
if pdf_path.lower().endswith(".pdf") and open(pdf_path, "rb").read(4) == b"%PDF":
|
762 |
+
logger.info(f"Loading file at {pdf_path} as PDF document")
|
763 |
+
pdf_work_paths.add(pdf_path)
|
764 |
+
elif is_image(pdf_path):
|
765 |
+
logger.info(f"Loading file at {pdf_path} as image document")
|
766 |
+
pdf_work_paths.add(pdf_path)
|
767 |
+
else:
|
768 |
+
raise ValueError(f"Unsupported file extension for {pdf_path}")
|
769 |
+
else:
|
770 |
+
raise ValueError(f"{pdf_path} does not exist")
|
771 |
+
|
772 |
+
logger.info(f"Found {len(pdf_work_paths):,} total pdf paths to add")
|
773 |
+
|
774 |
+
# Estimate average pages per pdf
|
775 |
+
sample_size = min(100, len(pdf_work_paths))
|
776 |
+
sampled_pdfs = random.sample(list(pdf_work_paths), sample_size)
|
777 |
+
page_counts = []
|
778 |
+
|
779 |
+
for pdf_path in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
|
780 |
+
try:
|
781 |
+
if pdf_path.lower().endswith(".pdf"):
|
782 |
+
reader = PdfReader(pdf_path)
|
783 |
+
page_counts.append(len(reader.pages))
|
784 |
+
else:
|
785 |
+
page_counts.append(1)
|
786 |
+
except Exception as e:
|
787 |
+
logger.warning(f"Failed to read {pdf_path}: {e}")
|
788 |
+
|
789 |
+
if page_counts:
|
790 |
+
avg_pages_per_pdf = sum(page_counts) / len(page_counts)
|
791 |
+
else:
|
792 |
+
logger.warning("Could not read any PDFs to estimate average page count.")
|
793 |
+
avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails
|
794 |
+
|
795 |
+
items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
|
796 |
+
logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")
|
797 |
+
|
798 |
+
# Now call populate_queue
|
799 |
+
await work_queue.populate_queue(pdf_work_paths, items_per_group)
|
800 |
+
elif args.task == 'merge_pages' or args.task == 'merge_tables':
|
801 |
+
json_work_paths = set()
|
802 |
+
for json_path in args.data:
|
803 |
+
if os.path.exists(json_path):
|
804 |
+
if json_path.lower().endswith(".json"):
|
805 |
+
json_work_paths.add(json_path)
|
806 |
+
elif json_path.lower().endswith(".txt"):
|
807 |
+
logger.info(f"Loading file at {json_path} as list of paths")
|
808 |
+
with open(json_path, "r") as f:
|
809 |
+
json_work_paths |= set(filter(None, (line.strip() for line in f)))
|
810 |
+
else:
|
811 |
+
raise ValueError(f"Unsupported file extension for {json_path}")
|
812 |
+
else:
|
813 |
+
raise ValueError(f"{json_path} does not exist")
|
814 |
+
|
815 |
+
# Now call populate_queue
|
816 |
+
await work_queue.populate_queue(json_work_paths, args.pages_per_group)
|
817 |
+
|
818 |
+
|
819 |
+
# If you get this far, then you are doing inference and need a GPU
|
820 |
+
check_vllm_version()
|
821 |
+
check_torch_gpu_available()
|
822 |
+
|
823 |
+
logger.info(f"Starting pipeline with PID {os.getpid()}")
|
824 |
+
|
825 |
+
# Download the model before you do anything else
|
826 |
+
await download_model(args.model)
|
827 |
+
|
828 |
+
# Initialize the work queue
|
829 |
+
qsize = await work_queue.initialize_queue()
|
830 |
+
|
831 |
+
if qsize == 0:
|
832 |
+
logger.info("No work to do, exiting")
|
833 |
+
return
|
834 |
+
# Create a semaphore to control worker access
|
835 |
+
# We only allow one worker to move forward with requests, until the server has no more requests in its queue
|
836 |
+
# This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
|
837 |
+
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
|
838 |
+
semaphore = asyncio.Semaphore(1)
|
839 |
+
|
840 |
+
vllm_server = asyncio.create_task(vllm_server_host(args, semaphore))
|
841 |
+
|
842 |
+
await vllm_server_ready(args)
|
843 |
+
|
844 |
+
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
|
845 |
+
|
846 |
+
# Create worker tasks to process the queue concurrently.
|
847 |
+
worker_tasks = []
|
848 |
+
for i in range(args.workers):
|
849 |
+
task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
|
850 |
+
worker_tasks.append(task)
|
851 |
+
|
852 |
+
# Wait for all worker tasks to finish
|
853 |
+
await asyncio.gather(*worker_tasks)
|
854 |
+
|
855 |
+
vllm_server.cancel()
|
856 |
+
metrics_task.cancel()
|
857 |
+
logger.info("Work done")
|
858 |
+
|
859 |
+
|
860 |
+
if __name__ == "__main__":
|
861 |
+
asyncio.run(main())
|
ocrflux/prompts.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
@dataclass(frozen=True)
|
6 |
+
class PageResponse:
|
7 |
+
primary_language: Optional[str]
|
8 |
+
is_rotation_valid: bool
|
9 |
+
rotation_correction: int
|
10 |
+
is_table: bool
|
11 |
+
is_diagram: bool
|
12 |
+
natural_text: Optional[str]
|
13 |
+
|
14 |
+
def __post_init__(self):
|
15 |
+
# Validate rotation_correction is one of the allowed values
|
16 |
+
if self.rotation_correction not in {0, 90, 180, 270}:
|
17 |
+
raise ValueError("rotation_correction must be one of [0, 90, 180, 270].")
|
18 |
+
|
19 |
+
# Type checks
|
20 |
+
if not isinstance(self.primary_language, (str, type(None))):
|
21 |
+
raise TypeError("primary_language must be of type Optional[str].")
|
22 |
+
if not isinstance(self.is_rotation_valid, bool):
|
23 |
+
raise TypeError("is_rotation_valid must be of type bool.")
|
24 |
+
if not isinstance(self.rotation_correction, int):
|
25 |
+
raise TypeError("rotation_correction must be of type int.")
|
26 |
+
if not isinstance(self.is_table, bool):
|
27 |
+
raise TypeError("is_table must be of type bool.")
|
28 |
+
if not isinstance(self.is_diagram, bool):
|
29 |
+
raise TypeError("is_diagram must be of type bool.")
|
30 |
+
if not isinstance(self.natural_text, (str, type(None))):
|
31 |
+
raise TypeError("natural_text must be of type Optional[str].")
|
32 |
+
|
33 |
+
def build_element_merge_detect_prompt(text_list_1,text_list_2) -> str:
|
34 |
+
task = '''Below are two consecutive pages in Markdown format, where each element of them is numbered. Identify pairs of elements which should be merged across the two pages, such as text paragraphs or tables that span across the two pages. Return pairs as [(element_index_of_page1, element_index_of_page2), ...] or [] if no elements should be merged.\n'''
|
35 |
+
task += "Previous page:\n"
|
36 |
+
for i,text in enumerate(text_list_1):
|
37 |
+
task += f"{i}. {text}\n\n"
|
38 |
+
task += "Next page:\n"
|
39 |
+
for i,text in enumerate(text_list_2):
|
40 |
+
task += f"{i}. {text}\n\n"
|
41 |
+
return task
|
42 |
+
|
43 |
+
def build_html_table_merge_prompt(table1,table2) -> str:
|
44 |
+
return (
|
45 |
+
f"Below are two tables in HTML format, merge them into one table in HTML format.\n"
|
46 |
+
f"TABLE 1:\n"
|
47 |
+
f"{table1}\n"
|
48 |
+
f"TABLE 2:\n"
|
49 |
+
f"{table2}\n"
|
50 |
+
)
|
51 |
+
|
52 |
+
def build_page_to_markdown_prompt() -> str:
|
53 |
+
return (
|
54 |
+
f"Below is the image of one page of a document. "
|
55 |
+
f"Just return the plain text representation of this document as if you were reading it naturally.\n"
|
56 |
+
f"ALL tables should be presented in HTML format.\n"
|
57 |
+
f"If there are images or figures in the page, present them as \"<Image>(left,top),(right,bottom)</Image>\", (left,top,right,bottom) are the coordinates of the top-left and bottom-right corners of the image or figure.\n"
|
58 |
+
f"Present all titles and headings as H1 headings.\n"
|
59 |
+
f"Do not hallucinate.\n"
|
60 |
+
)
|
ocrflux/table_format.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from bs4 import BeautifulSoup
|
3 |
+
import re
|
4 |
+
|
5 |
+
def is_html_table(text):
|
6 |
+
soup = BeautifulSoup(text, "html.parser")
|
7 |
+
return soup.find('table') is not None
|
8 |
+
|
9 |
+
def table_matrix2html(matrix_table):
|
10 |
+
soup = BeautifulSoup(matrix_table, 'html.parser')
|
11 |
+
table = soup.find('table')
|
12 |
+
rownum = 0
|
13 |
+
colnum = 0
|
14 |
+
cell_dict = {}
|
15 |
+
rid = 0
|
16 |
+
for tr in table.find_all('tr'):
|
17 |
+
cid = 0
|
18 |
+
for td in tr.find_all('td'):
|
19 |
+
if td.find('l'):
|
20 |
+
cell_dict[(rid, cid)] = '<l>'
|
21 |
+
elif td.find('t'):
|
22 |
+
cell_dict[(rid, cid)] = '<t>'
|
23 |
+
elif td.find('lt'):
|
24 |
+
cell_dict[(rid, cid)] = '<lt>'
|
25 |
+
else:
|
26 |
+
text = td.get_text(strip=True)
|
27 |
+
cell_dict[(rid, cid)] = text
|
28 |
+
cid += 1
|
29 |
+
if colnum == 0:
|
30 |
+
colnum = cid
|
31 |
+
elif cid != colnum:
|
32 |
+
raise Exception('colnum not match')
|
33 |
+
rid += 1
|
34 |
+
rownum = rid
|
35 |
+
html_table = ['<table>']
|
36 |
+
for rid in range(rownum):
|
37 |
+
html_table.append('<tr>')
|
38 |
+
for cid in range(colnum):
|
39 |
+
if (rid, cid) not in cell_dict.keys():
|
40 |
+
continue
|
41 |
+
text = cell_dict[(rid, cid)]
|
42 |
+
if text == '<l>' or text == '<t>' or text == '<lt>':
|
43 |
+
raise Exception('cell not match')
|
44 |
+
rowspan = 1
|
45 |
+
colspan = 1
|
46 |
+
for r in range(rid+1, rownum):
|
47 |
+
if (r, cid) in cell_dict.keys() and cell_dict[(r, cid)] == '<t>':
|
48 |
+
rowspan += 1
|
49 |
+
del cell_dict[(r, cid)]
|
50 |
+
else:
|
51 |
+
break
|
52 |
+
for c in range(cid+1, colnum):
|
53 |
+
if (rid, c) in cell_dict.keys() and cell_dict[(rid, c)] == '<l>':
|
54 |
+
colspan += 1
|
55 |
+
del cell_dict[(rid, c)]
|
56 |
+
else:
|
57 |
+
break
|
58 |
+
for r in range(rid+1, rid+rowspan):
|
59 |
+
for c in range(cid+1, cid+colspan):
|
60 |
+
if cell_dict[(r, c)] != '<lt>':
|
61 |
+
raise Exception('cell not match')
|
62 |
+
del cell_dict[(r, c)]
|
63 |
+
attr = ''
|
64 |
+
if rowspan > 1:
|
65 |
+
attr += ' rowspan="{}"'.format(rowspan)
|
66 |
+
if colspan > 1:
|
67 |
+
attr += ' colspan="{}"'.format(colspan)
|
68 |
+
html_table.append("<td{}>{}</td>".format(attr, text))
|
69 |
+
html_table.append('</tr>')
|
70 |
+
html_table.append('</table>')
|
71 |
+
return "".join(html_table)
|
72 |
+
|
73 |
+
def table_html2matrix(html_table):
|
74 |
+
soup = BeautifulSoup(html_table, 'html.parser')
|
75 |
+
table = soup.find('table')
|
76 |
+
rownum = len(table.find_all('tr'))
|
77 |
+
colnum = 0
|
78 |
+
tr = table.find_all('tr')[0]
|
79 |
+
for td in tr.find_all('td'):
|
80 |
+
colnum += td.get('colspan', 1)
|
81 |
+
matrix = [[None for _ in range(colnum)] for _ in range(rownum)]
|
82 |
+
|
83 |
+
rid = 0
|
84 |
+
for tr in table.find_all('tr'):
|
85 |
+
cid = 0
|
86 |
+
for td in tr.find_all('td'):
|
87 |
+
for c in range(cid, colnum):
|
88 |
+
if matrix[rid][c] is None:
|
89 |
+
break
|
90 |
+
cid = c
|
91 |
+
rowspan = td.get('rowspan', 1)
|
92 |
+
colspan = td.get('colspan', 1)
|
93 |
+
cell_text = td.get_text(strip=True)
|
94 |
+
for r in range(rid,rid+rowspan):
|
95 |
+
if r >= rownum:
|
96 |
+
raise Exception('rownum not match')
|
97 |
+
for c in range(cid,cid+colspan):
|
98 |
+
if c >= colnum:
|
99 |
+
raise Exception('colnum not match')
|
100 |
+
if matrix[r][c] is not None:
|
101 |
+
raise Exception('cell not match')
|
102 |
+
if r == rid and c == cid:
|
103 |
+
matrix[r][c] = cell_text
|
104 |
+
elif r == rid:
|
105 |
+
matrix[r][c] = '<l>'
|
106 |
+
elif c == cid:
|
107 |
+
matrix[r][c] = '<t>'
|
108 |
+
else:
|
109 |
+
matrix[r][c] = '<lt>'
|
110 |
+
cid += colspan
|
111 |
+
rid += 1
|
112 |
+
|
113 |
+
matrix_table = ['<table>']
|
114 |
+
for rid in range(rownum):
|
115 |
+
matrix_table.append('<tr>')
|
116 |
+
for cid in range(colnum):
|
117 |
+
matrix_table.append('<td>')
|
118 |
+
cell_text = matrix[rid][cid]
|
119 |
+
matrix_table.append(cell_text)
|
120 |
+
matrix_table.append('</td>')
|
121 |
+
matrix_table.append('</tr>')
|
122 |
+
matrix_table.append('</table>')
|
123 |
+
return "".join(matrix_table)
|
124 |
+
|
125 |
+
trans_func = {
|
126 |
+
"html2matrix": table_html2matrix,
|
127 |
+
"matrix2html": table_matrix2html,
|
128 |
+
}
|
129 |
+
|
130 |
+
def trans_markdown_text(markdown_text,trans_type):
|
131 |
+
if markdown_text == None:
|
132 |
+
return None
|
133 |
+
text_list = markdown_text.split('\n\n')
|
134 |
+
for i,text in enumerate(text_list):
|
135 |
+
if is_html_table(text):
|
136 |
+
text_list[i] = trans_func[trans_type](text)
|
137 |
+
return "\n\n".join(text_list)
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
|
ocrflux/work_queue.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import asyncio
|
3 |
+
import datetime
|
4 |
+
import hashlib
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
from asyncio import Queue
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from typing import Any, List, Optional
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class WorkItem:
|
17 |
+
"""Represents a single work item in the queue"""
|
18 |
+
|
19 |
+
hash: str
|
20 |
+
work_paths: List[str]
|
21 |
+
|
22 |
+
|
23 |
+
class WorkQueue(abc.ABC):
|
24 |
+
"""
|
25 |
+
Base class defining the interface for a work queue.
|
26 |
+
"""
|
27 |
+
|
28 |
+
@abc.abstractmethod
|
29 |
+
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
|
30 |
+
"""
|
31 |
+
Add new items to the work queue. The specifics will vary depending on
|
32 |
+
whether this is a local or S3-backed queue.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
work_paths: Each individual path that we will process over
|
36 |
+
items_per_group: Number of items to group together in a single work item
|
37 |
+
"""
|
38 |
+
pass
|
39 |
+
|
40 |
+
@abc.abstractmethod
|
41 |
+
async def initialize_queue(self) -> int:
|
42 |
+
"""
|
43 |
+
Load the work queue from the relevant store (local or remote)
|
44 |
+
and initialize it for processing.
|
45 |
+
|
46 |
+
For example, this might remove already completed work items and randomize
|
47 |
+
the order before adding them to an internal queue.
|
48 |
+
"""
|
49 |
+
pass
|
50 |
+
|
51 |
+
@abc.abstractmethod
|
52 |
+
async def is_completed(self, work_hash: str) -> bool:
|
53 |
+
"""
|
54 |
+
Check if a work item has been completed.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
work_hash: Hash of the work item to check
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
True if the work is completed, False otherwise
|
61 |
+
"""
|
62 |
+
pass
|
63 |
+
|
64 |
+
@abc.abstractmethod
|
65 |
+
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
|
66 |
+
"""
|
67 |
+
Get the next available work item that isn't completed or locked.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
worker_lock_timeout_secs: Number of seconds before considering
|
71 |
+
a worker lock stale (default 30 mins)
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
WorkItem if work is available, None if queue is empty
|
75 |
+
"""
|
76 |
+
pass
|
77 |
+
|
78 |
+
@abc.abstractmethod
|
79 |
+
async def mark_done(self, work_item: WorkItem) -> None:
|
80 |
+
"""
|
81 |
+
Mark a work item as done by removing its lock file
|
82 |
+
or performing any other cleanup.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
work_item: The WorkItem to mark as done
|
86 |
+
"""
|
87 |
+
pass
|
88 |
+
|
89 |
+
@property
|
90 |
+
@abc.abstractmethod
|
91 |
+
def size(self) -> int:
|
92 |
+
"""Get current size of work queue"""
|
93 |
+
pass
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def _compute_workgroup_hash(work_paths: List[str]) -> str:
|
97 |
+
"""
|
98 |
+
Compute a deterministic hash for a group of paths.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
work_paths: List of paths (local or S3)
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
SHA1 hash of the sorted paths
|
105 |
+
"""
|
106 |
+
sha1 = hashlib.sha1()
|
107 |
+
for path in sorted(work_paths):
|
108 |
+
sha1.update(path.encode("utf-8"))
|
109 |
+
return sha1.hexdigest()
|
110 |
+
|
111 |
+
|
112 |
+
# --------------------------------------------------------------------------------------
|
113 |
+
# Local Helpers for reading/writing the index CSV (compressed with zstd) to disk
|
114 |
+
# --------------------------------------------------------------------------------------
|
115 |
+
|
116 |
+
try:
|
117 |
+
import zstandard
|
118 |
+
except ImportError:
|
119 |
+
zstandard = None
|
120 |
+
|
121 |
+
|
122 |
+
def download_zstd_csv_local(local_path: str) -> List[str]:
|
123 |
+
"""
|
124 |
+
Download a zstd-compressed CSV from a local path.
|
125 |
+
If the file doesn't exist, returns an empty list.
|
126 |
+
"""
|
127 |
+
if not os.path.exists(local_path):
|
128 |
+
return []
|
129 |
+
|
130 |
+
if not zstandard:
|
131 |
+
raise RuntimeError("zstandard package is required for local zstd CSV operations.")
|
132 |
+
|
133 |
+
with open(local_path, "rb") as f:
|
134 |
+
dctx = zstandard.ZstdDecompressor()
|
135 |
+
data = dctx.decompress(f.read())
|
136 |
+
lines = data.decode("utf-8").splitlines()
|
137 |
+
return lines
|
138 |
+
|
139 |
+
|
140 |
+
def upload_zstd_csv_local(local_path: str, lines: List[str]) -> None:
|
141 |
+
"""
|
142 |
+
Upload a zstd-compressed CSV to a local path.
|
143 |
+
"""
|
144 |
+
if not zstandard:
|
145 |
+
raise RuntimeError("zstandard package is required for local zstd CSV operations.")
|
146 |
+
|
147 |
+
data = "\n".join(lines).encode("utf-8")
|
148 |
+
cctx = zstandard.ZstdCompressor()
|
149 |
+
compressed_data = cctx.compress(data)
|
150 |
+
|
151 |
+
# Ensure parent directories exist
|
152 |
+
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
153 |
+
|
154 |
+
with open(local_path, "wb") as f:
|
155 |
+
f.write(compressed_data)
|
156 |
+
|
157 |
+
|
158 |
+
# --------------------------------------------------------------------------------------
|
159 |
+
# LocalWorkQueue Implementation
|
160 |
+
# --------------------------------------------------------------------------------------
|
161 |
+
|
162 |
+
|
163 |
+
class LocalWorkQueue(WorkQueue):
|
164 |
+
"""
|
165 |
+
A local in-memory and on-disk WorkQueue implementation, which uses
|
166 |
+
a local workspace directory to store the queue index, lock files,
|
167 |
+
and completed results for persistent resumption across process restarts.
|
168 |
+
"""
|
169 |
+
|
170 |
+
def __init__(self, workspace_path: str):
|
171 |
+
"""
|
172 |
+
Initialize the local work queue.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
workspace_path: Local directory path where the queue index,
|
176 |
+
results, and locks are stored.
|
177 |
+
"""
|
178 |
+
self.workspace_path = os.path.abspath(workspace_path)
|
179 |
+
os.makedirs(self.workspace_path, exist_ok=True)
|
180 |
+
|
181 |
+
# Local index file (compressed)
|
182 |
+
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
|
183 |
+
|
184 |
+
# Output directory for completed tasks
|
185 |
+
self._results_dir = os.path.join(self.workspace_path, "results")
|
186 |
+
os.makedirs(self._results_dir, exist_ok=True)
|
187 |
+
|
188 |
+
# Directory for lock files
|
189 |
+
self._locks_dir = os.path.join(self.workspace_path, "worker_locks")
|
190 |
+
os.makedirs(self._locks_dir, exist_ok=True)
|
191 |
+
|
192 |
+
# Internal queue
|
193 |
+
self._queue: Queue[Any] = Queue()
|
194 |
+
|
195 |
+
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
|
196 |
+
"""
|
197 |
+
Add new items to the work queue (local version).
|
198 |
+
|
199 |
+
Args:
|
200 |
+
work_paths: Each individual path (local in this context)
|
201 |
+
that we will process over
|
202 |
+
items_per_group: Number of items to group together in a single work item
|
203 |
+
"""
|
204 |
+
# Treat them as local paths, but keep variable name for consistency
|
205 |
+
all_paths = set(work_paths)
|
206 |
+
logger.info(f"Found {len(all_paths):,} total paths")
|
207 |
+
|
208 |
+
# Load existing work groups from local index
|
209 |
+
existing_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
|
210 |
+
existing_groups = {}
|
211 |
+
for line in existing_lines:
|
212 |
+
if line.strip():
|
213 |
+
parts = line.strip().split(",")
|
214 |
+
group_hash = parts[0]
|
215 |
+
group_paths = parts[1:]
|
216 |
+
existing_groups[group_hash] = group_paths
|
217 |
+
|
218 |
+
existing_path_set = {p for paths in existing_groups.values() for p in paths}
|
219 |
+
new_paths = all_paths - existing_path_set
|
220 |
+
logger.info(f"{len(new_paths):,} new paths to add to the workspace")
|
221 |
+
|
222 |
+
if not new_paths:
|
223 |
+
return
|
224 |
+
|
225 |
+
# Create new work groups
|
226 |
+
new_groups = []
|
227 |
+
current_group = []
|
228 |
+
for path in sorted(new_paths):
|
229 |
+
current_group.append(path)
|
230 |
+
if len(current_group) == items_per_group:
|
231 |
+
group_hash = self._compute_workgroup_hash(current_group)
|
232 |
+
new_groups.append((group_hash, current_group))
|
233 |
+
current_group = []
|
234 |
+
if current_group:
|
235 |
+
group_hash = self._compute_workgroup_hash(current_group)
|
236 |
+
new_groups.append((group_hash, current_group))
|
237 |
+
|
238 |
+
logger.info(f"Created {len(new_groups):,} new work groups")
|
239 |
+
|
240 |
+
# Combine and save updated work groups
|
241 |
+
combined_groups = existing_groups.copy()
|
242 |
+
for group_hash, group_paths in new_groups:
|
243 |
+
combined_groups[group_hash] = group_paths
|
244 |
+
|
245 |
+
combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
|
246 |
+
|
247 |
+
if new_groups:
|
248 |
+
# Write the combined data back to disk in zstd CSV format
|
249 |
+
await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines)
|
250 |
+
|
251 |
+
async def initialize_queue(self) -> int:
|
252 |
+
"""
|
253 |
+
Load the work queue from the local index file and initialize it for processing.
|
254 |
+
Removes already completed work items and randomizes the order.
|
255 |
+
"""
|
256 |
+
# 1) Read the index
|
257 |
+
work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
|
258 |
+
work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
|
259 |
+
|
260 |
+
# 2) Determine which items are completed by scanning local results/*.jsonl
|
261 |
+
if not os.path.isdir(self._results_dir):
|
262 |
+
os.makedirs(self._results_dir, exist_ok=True)
|
263 |
+
done_work_items = [f for f in os.listdir(self._results_dir) if f.startswith("output_") and f.endswith(".jsonl")]
|
264 |
+
done_work_hashes = {fn[len("output_") : -len(".jsonl")] for fn in done_work_items}
|
265 |
+
|
266 |
+
# 3) Filter out completed items
|
267 |
+
remaining_work_hashes = set(work_queue) - done_work_hashes
|
268 |
+
remaining_items = [WorkItem(hash=hash_, work_paths=work_queue[hash_]) for hash_ in remaining_work_hashes]
|
269 |
+
random.shuffle(remaining_items)
|
270 |
+
|
271 |
+
# 4) Initialize our in-memory queue
|
272 |
+
self._queue = asyncio.Queue()
|
273 |
+
for item in remaining_items:
|
274 |
+
await self._queue.put(item)
|
275 |
+
|
276 |
+
logger.info(f"Initialized local queue with {self._queue.qsize()} work items")
|
277 |
+
|
278 |
+
return self._queue.qsize()
|
279 |
+
|
280 |
+
async def is_completed(self, work_hash: str) -> bool:
|
281 |
+
"""
|
282 |
+
Check if a work item has been completed locally by seeing if
|
283 |
+
output_{work_hash}.jsonl is present in the results directory.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
work_hash: Hash of the work item to check
|
287 |
+
"""
|
288 |
+
output_file = os.path.join(self._results_dir, f"output_{work_hash}.jsonl")
|
289 |
+
return os.path.exists(output_file)
|
290 |
+
|
291 |
+
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
|
292 |
+
"""
|
293 |
+
Get the next available work item that isn't completed or locked.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
worker_lock_timeout_secs: Number of seconds before considering
|
297 |
+
a worker lock stale (default 30 mins)
|
298 |
+
|
299 |
+
Returns:
|
300 |
+
WorkItem if work is available, None if queue is empty
|
301 |
+
"""
|
302 |
+
while True:
|
303 |
+
try:
|
304 |
+
work_item = self._queue.get_nowait()
|
305 |
+
except asyncio.QueueEmpty:
|
306 |
+
return None
|
307 |
+
|
308 |
+
# Check if work is already completed
|
309 |
+
if await self.is_completed(work_item.hash):
|
310 |
+
logger.debug(f"Work item {work_item.hash} already completed, skipping")
|
311 |
+
self._queue.task_done()
|
312 |
+
continue
|
313 |
+
|
314 |
+
# Check for worker lock
|
315 |
+
lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
|
316 |
+
if os.path.exists(lock_file):
|
317 |
+
# Check modification time
|
318 |
+
mtime = datetime.datetime.fromtimestamp(os.path.getmtime(lock_file), datetime.timezone.utc)
|
319 |
+
if (datetime.datetime.now(datetime.timezone.utc) - mtime).total_seconds() > worker_lock_timeout_secs:
|
320 |
+
# Lock is stale, we can take this work
|
321 |
+
logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
|
322 |
+
else:
|
323 |
+
# Lock is active, skip this work
|
324 |
+
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
|
325 |
+
self._queue.task_done()
|
326 |
+
continue
|
327 |
+
|
328 |
+
# Create our lock file (touch an empty file)
|
329 |
+
try:
|
330 |
+
with open(lock_file, "wb") as f:
|
331 |
+
f.write(b"")
|
332 |
+
except Exception as e:
|
333 |
+
logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
|
334 |
+
self._queue.task_done()
|
335 |
+
continue
|
336 |
+
|
337 |
+
return work_item
|
338 |
+
|
339 |
+
async def mark_done(self, work_item: WorkItem) -> None:
|
340 |
+
"""
|
341 |
+
Mark a work item as done by removing its lock file.
|
342 |
+
|
343 |
+
Args:
|
344 |
+
work_item: The WorkItem to mark as done
|
345 |
+
"""
|
346 |
+
lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
|
347 |
+
if os.path.exists(lock_file):
|
348 |
+
try:
|
349 |
+
os.remove(lock_file)
|
350 |
+
except Exception as e:
|
351 |
+
logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
|
352 |
+
self._queue.task_done()
|
353 |
+
|
354 |
+
@property
|
355 |
+
def size(self) -> int:
|
356 |
+
"""Get current size of local work queue"""
|
357 |
+
return self._queue.qsize()
|
pyproject.toml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools", "wheel"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "ocrflux"
|
7 |
+
description = "Fast, efficient, and high quality OCR powered by open visual language models"
|
8 |
+
version = "0.1.0"
|
9 |
+
readme = "README.md"
|
10 |
+
classifiers = [
|
11 |
+
"Intended Audience :: Science/Research",
|
12 |
+
"Development Status :: 3 - Alpha",
|
13 |
+
"License :: OSI Approved :: Apache Software License",
|
14 |
+
"Programming Language :: Python :: 3",
|
15 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
16 |
+
]
|
17 |
+
authors = [
|
18 |
+
{name = "Yu Tang", email = "[email protected]"}
|
19 |
+
]
|
20 |
+
requires-python = ">=3.11"
|
21 |
+
dependencies = [
|
22 |
+
"cached-path",
|
23 |
+
"smart_open",
|
24 |
+
"pypdf>=5.2.0",
|
25 |
+
"pypdfium2",
|
26 |
+
"cryptography",
|
27 |
+
"lingua-language-detector",
|
28 |
+
"Pillow",
|
29 |
+
"ftfy",
|
30 |
+
"bleach",
|
31 |
+
"markdown2",
|
32 |
+
"filelock",
|
33 |
+
"orjson",
|
34 |
+
"requests",
|
35 |
+
"zstandard",
|
36 |
+
"boto3",
|
37 |
+
"httpx",
|
38 |
+
"torch>=2.5.1",
|
39 |
+
"transformers==4.50.0",
|
40 |
+
"vllm==0.7.3",
|
41 |
+
"img2pdf",
|
42 |
+
"nltk",
|
43 |
+
"bs4",
|
44 |
+
"distance",
|
45 |
+
"apted",
|
46 |
+
"gradio",
|
47 |
+
"gradio_pdf",
|
48 |
+
]
|
49 |
+
license = {file = "LICENSE"}
|
50 |
+
|
51 |
+
[project.urls]
|
52 |
+
Homepage = "https://github.com/chatdoc-com/OCRFlux"
|
53 |
+
Repository = "https://github.com/chatdoc-com/OCRFlux"
|
54 |
+
|
55 |
+
[tool.setuptools.packages.find]
|
56 |
+
exclude = [
|
57 |
+
"*.tests",
|
58 |
+
"*.tests.*",
|
59 |
+
"tests.*",
|
60 |
+
"tests",
|
61 |
+
"docs*",
|
62 |
+
"scripts*",
|
63 |
+
"images*"
|
64 |
+
]
|
65 |
+
|
66 |
+
[tool.setuptools]
|
67 |
+
include-package-data = true
|
68 |
+
|
69 |
+
[tool.setuptools.package-data]
|
70 |
+
ocrflux = [
|
71 |
+
"py.typed",
|
72 |
+
]
|
73 |
+
|
74 |
+
[tool.black]
|
75 |
+
line-length = 79
|