mirnaresearch committed on
Commit ca5b08e · 0 Parent(s)

Initial commit for HF Space (no images)

.dockerignore ADDED
@@ -0,0 +1,5 @@
1
+ *
2
+ **/*
3
+ .*
4
+ !ocrflux
5
+ !pyproject.toml
.github/workflows/docker.yml ADDED
@@ -0,0 +1,46 @@
1
+ on:
2
+ push:
3
+ branches:
4
+ - main
5
+ - 'v*.*.*'
6
+
7
+ jobs:
8
+ build_and_push_docker:
9
+ runs-on: ubuntu-latest
10
+
11
+ steps:
12
+ - name: Checkout repository
13
+ uses: actions/checkout@v4
14
+
15
+ - name: Set up Docker Buildx
16
+ uses: docker/setup-buildx-action@v3
17
+
18
+ - name: Log in to Docker Hub
19
+ uses: docker/login-action@v3
20
+ with:
21
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
22
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
23
+
24
+ - name: Determine image tags
25
+ id: determine_tags
26
+ run: |
27
+ BRANCH_NAME=${{ github.ref_name }}
28
+ DOCKER_IMAGE_NAME="chatdoc/ocrflux"
29
+
30
+ if [[ "$BRANCH_NAME" == "main" ]]; then
31
+ echo "IMAGE_TAGS=$DOCKER_IMAGE_NAME:latest,$DOCKER_IMAGE_NAME:$BRANCH_NAME"
32
+ echo "image_tags=$DOCKER_IMAGE_NAME:latest,$DOCKER_IMAGE_NAME:$BRANCH_NAME" >> $GITHUB_OUTPUT
33
+ else
34
+ echo "IMAGE_TAGS=$DOCKER_IMAGE_NAME:$BRANCH_NAME"
35
+ echo "image_tags=$DOCKER_IMAGE_NAME:$BRANCH_NAME" >> $GITHUB_OUTPUT
36
+ fi
37
+
38
+ - name: Build and push Docker image
39
+ id: docker_build
40
+ uses: docker/build-push-action@v6
41
+ with:
42
+ context: .
43
+ push: true
44
+ tags: ${{ steps.determine_tags.outputs.image_tags }}
45
+ cache-from: type=gha,scope=${{ github.workflow }}
46
+ cache-to: type=gha,scope=${{ github.workflow }},mode=max
.gitignore ADDED
@@ -0,0 +1,194 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Abstra
171
+ # Abstra is an AI-powered process automation framework.
172
+ # Ignore directories containing user credentials, local state, and settings.
173
+ # Learn more at https://abstra.io/docs
174
+ .abstra/
175
+
176
+ # Visual Studio Code
177
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
178
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
179
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
180
+ # you could uncomment the following to ignore the enitre vscode folder
181
+ # .vscode/
182
+
183
+ # Ruff stuff:
184
+ .ruff_cache/
185
+
186
+ # PyPI configuration file
187
+ .pypirc
188
+
189
+ # Cursor
190
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
191
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
192
+ # refer to https://docs.cursor.com/context/ignore-files
193
+ .cursorignore
194
+ .cursorindexingignore
Dockerfile ADDED
@@ -0,0 +1,43 @@
1
+ FROM ubuntu:24.04
2
+
3
+ WORKDIR /OCRFlux
4
+
5
+ ENV LANG=en_US.UTF-8 \
6
+ PIP_ROOT_USER_ACTION=ignore \
7
+ PIP_BREAK_SYSTEM_PACKAGES=true \
8
+ PIP_NO_CACHE_DIR=true \
9
+ PIP_DISABLE_PIP_VERSION_CHECK=true \
10
+ PYTHONPATH=/OCRFlux
11
+
12
+ SHELL ["/bin/bash", "-c"]
13
+
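+ # Copy the bind-mounted build context into /OCRFlux, install system fonts, poppler and Python 3.12,
+ # then install the package with pip and clean apt/pip caches, all in a single layer to keep the image small.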
14
+ RUN --mount=type=bind,source=./,target=/builder \
15
+ cp -a /builder/. /OCRFlux/ && \
16
+ set -o pipefail && \
17
+ apt-get update && \
18
+ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
19
+ ca-certificates \
20
+ curl \
21
+ fonts-crosextra-caladea \
22
+ fonts-crosextra-carlito \
23
+ gsfonts \
24
+ lcdf-typetools \
25
+ locales \
26
+ msttcorefonts \
27
+ poppler-utils \
28
+ poppler-data \
29
+ python3.12-dev \
30
+ python3.12-full \
31
+ software-properties-common \
32
+ ttf-mscorefonts-installer && \
33
+ locale-gen en_US.UTF-8 && \
34
+ curl https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py && \
35
+ python3.12 /tmp/get-pip.py && \
36
+ python3.12 -m pip install . --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ && \
37
+ rm -rf ./* \
38
+ /var/lib/apt/lists/* \
39
+ /tmp/* \
40
+ /root/.cache/pip &&\
41
+ find /var/log /var/cache -type f -delete
42
+
43
+ ENTRYPOINT ["python3.12", "-m", "ocrflux.pipeline"]
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,397 @@
1
+ <div align="center">
2
+ <img src="./images/OCRFlux.png" alt="OCRFlux Logo" width="300"/>
3
+ <hr/>
4
+ </div>
5
+ <p align="center">
6
+ <a href="https://github.com/chatdoc-com/OCRFlux/blob/main/LICENSE">
7
+ <img alt="GitHub License" src="./images/license.svg" height="20">
8
+ </a>
9
+ <a href="https://github.com/chatdoc-com/OCRFlux/releases">
10
+ <img alt="GitHub release" src="./images/release.svg" height="20">
11
+ </a>
12
+ <a href="https://ocrflux.pdfparser.io/">
13
+ <img alt="Demo" src="./images/demo.svg" height="20">
14
+ </a>
15
+ <a href="https://discord.gg/F33mhsAqqg">
16
+ <img alt="Discord" src="./images/discord.svg" height="20">
17
+ </a>
18
+ </p>
19
+
20
+ OCRFlux is a toolkit built on a multimodal large language model for converting PDFs and images into clean, readable, plain Markdown text. It aims to push the current state of the art to a significantly higher level.
21
+
22
+ Try the online demo: [OCRFlux Demo](https://ocrflux.pdfparser.io/)
23
+
24
+ Functions: **Whole file parsing**
25
+ - On each page
26
+ - Convert into text with a natural reading order, even in the presence of multi-column layouts, figures, and insets
27
+ - Support for complicated tables and equations
28
+ - Automatically removes headers and footers
29
+
30
+ - Cross-page table/paragraph merging
31
+ - Cross-page table merging
32
+ - Cross-page paragraph merging
33
+
34
+
35
+ Key features:
36
+ - Superior parsing quality on each page
37
+
38
+ On our released benchmark [OCRFlux-bench-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-single), it achieves an Edit Distance Similarity (EDS) of 0.967, which is 0.095 higher than the baseline [olmOCR-7B-0225-preview](https://huggingface.co/allenai/olmOCR-7B-0225-preview) (0.872), 0.109 higher than [Nanonets-OCR-s](https://huggingface.co/nanonets/Nanonets-OCR-s) (0.858), and 0.187 higher than [MonkeyOCR](https://huggingface.co/echo840/MonkeyOCR) (0.780).
39
+
40
+ - Native support for cross-page table/paragraph merging (to the best of our knowledge, this is the first open-source project to support this feature).
41
+
42
+ - Based on a 3B-parameter VLM, so it can run even on an RTX 3090 GPU.
43
+
44
+ Release:
45
+ - [OCRFlux-3B](https://huggingface.co/ChatDOC/OCRFlux-3B) - 3B parameter VLM
46
+ - Benchmark for evaluation
47
+ - [OCRFlux-bench-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-single)
48
+ - [OCRFlux-pubtabnet-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-single)
49
+ - [OCRFlux-bench-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-cross)
50
+ - [OCRFlux-pubtabnet-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-cross)
51
+
52
+
53
+ ### News
54
+ - Jun 17, 2025 - v0.1.0 - Initial public launch and demo.
55
+
56
+ ### Benchmark for single-page parsing
57
+
58
+ We ship two comprehensive benchmarks to help measure the performance of our OCR system in single-page parsing:
59
+
60
+ - [OCRFlux-bench-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-single): Contains 2000 PDF pages (1000 English and 1000 Chinese) and their ground-truth Markdowns (manually labeled with multi-round checks).
61
+
62
+ - [OCRFlux-pubtabnet-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-single): Derived from the public [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet) benchmark with some format transformation. It contains 9064 HTML table samples, which are split into simple tables and complex tables according to whether they have rowspan and colspan cells.
63
+
64
+ We emphasize that the released benchmarks are NOT included in our training and evaluation data. The main results are as follows:
65
+
66
+
67
+ 1. In [OCRFlux-bench-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-single), we calculated the Edit Distance Similarity (EDS) between the generated Markdowns and the ground-truth Markdowns as the metric.
68
+
69
+ <table>
70
+ <thead>
71
+ <tr>
72
+ <th>Language</th>
73
+ <th>Model</th>
74
+ <th>Avg EDS ↑</th>
75
+ </tr>
76
+ </thead>
77
+ <tbody>
78
+ <tr>
79
+ <td rowspan="4">English</td>
80
+ <td>olmOCR-7B-0225-preview</td>
81
+ <td>0.885</td>
82
+ </tr>
83
+ <tr>
84
+ <td>Nanonets-OCR-s</td>
85
+ <td>0.870</td>
86
+ </tr>
87
+ <tr>
88
+ <td>MonkeyOCR</td>
89
+ <td>0.828</td>
90
+ </tr>
91
+ <tr>
92
+ <td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
93
+ <td>0.971</td>
94
+ </tr>
95
+ <tr>
96
+ <td rowspan="4">Chinese</td>
97
+ <td>olmOCR-7B-0225-preview</td>
98
+ <td>0.859</td>
99
+ </tr>
100
+ <tr>
101
+ <td>Nanonets-OCR-s</td>
102
+ <td>0.846</td>
103
+ </tr>
104
+ <tr>
105
+ <td>MonkeyOCR</td>
106
+ <td>0.731</td>
107
+ </tr>
108
+ <tr>
109
+ <td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
110
+ <td>0.962</td>
111
+ </tr>
112
+ <tr>
113
+ <td rowspan="4">Total</td>
114
+ <td>olmOCR-7B-0225-preview</td>
115
+ <td>0.872</td>
116
+ </tr>
117
+ <tr>
118
+ <td>Nanonets-OCR-s</td>
119
+ <td>0.858</td>
120
+ </tr>
121
+ <tr>
122
+ <td>MonkeyOCR</td>
123
+ <td>0.780</td>
124
+ </tr>
125
+ <tr>
126
+ <td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
127
+ <td>0.967</td>
128
+ </tr>
129
+ </tbody>
130
+ </table>
131
+
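+ The EDS metric itself is simple; a minimal sketch (mirroring the `evaluate` helper in `eval/eval_page_to_markdown.py`) is shown below:
+
+ ```python
+ import nltk
+
+ def eds(pred_markdown: str, gt_markdown: str) -> float:
+     # Edit Distance Similarity: 1 minus the normalized character-level edit distance.
+     edit_dist = nltk.edit_distance(pred_markdown, gt_markdown) / max(len(pred_markdown), len(gt_markdown))
+     return 1.0 - edit_dist
+ ```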
132
+ 2. In [OCRFlux-pubtabnet-single](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-single), we calculated the Tree Edit Distance-based Similarity (TEDS) between the generated HTML tables and the ground-truth HTML tables as the metric.
133
+ <table>
134
+ <thead>
135
+ <tr>
136
+ <th>Type</th>
137
+ <th>Model</th>
138
+ <th>Avg TEDS ↑</th>
139
+ </tr>
140
+ </thead>
141
+ <tbody>
142
+ <tr>
143
+ <td rowspan="4">Simple</td>
144
+ <td>olmOCR-7B-0225-preview</td>
145
+ <td>0.810</td>
146
+ </tr>
147
+ <tr>
148
+ <td>Nanonets-OCR-s</td>
149
+ <td>0.882</td>
150
+ </tr>
151
+ <tr>
152
+ <td>MonkeyOCR</td>
153
+ <td>0.880</td>
154
+ </tr>
155
+ <tr>
156
+ <td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
157
+ <td>0.912</td>
158
+ </tr>
159
+ <tr>
160
+ <td rowspan="4">Complex</td>
161
+ <td>olmOCR-7B-0225-preview</td>
162
+ <td>0.676</td>
163
+ </tr>
164
+ <tr>
165
+ <td>Nanonets-OCR-s</td>
166
+ <td>0.772</td>
167
+ </tr>
168
+ <tr>
169
+ <td><strong>MonkeyOCR</strong></td>
170
+ <td>0.826</td>
171
+ </tr>
172
+ <tr>
173
+ <td><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></td>
174
+ <td>0.807</td>
175
+ </tr>
176
+ <tr>
177
+ <td rowspan="4">Total</td>
178
+ <td>olmOCR-7B-0225-preview</td>
179
+ <td>0.744</td>
180
+ </tr>
181
+ <tr>
182
+ <td>Nanonets-OCR-s</td>
183
+ <td>0.828</td>
184
+ </tr>
185
+ <tr>
186
+ <td>MonkeyOCR</td>
187
+ <td>0.853</td>
188
+ </tr>
189
+ <tr>
190
+ <td><strong><a href="https://huggingface.co/ChatDOC/OCRFlux-3B">OCRFlux-3B</a></strong></td>
191
+ <td>0.861</td>
192
+ </tr>
193
+ </tbody>
194
+ </table>
195
+
196
+ We also present some case studies in the [blog](https://ocrflux.pdfparser.io/#/blog) article to illustrate the superiority of our model.
197
+
198
+ ### Benchmark for cross-page table/paragraph merging
199
+
200
+ PDF documents are typically paginated, which often results in tables or paragraphs being split across consecutive pages. Accurately detecting and merging such cross-page structures is crucial to avoid generating incomplete or fragmented content.
201
+
202
+ The detection task can be formulated as follows: given the Markdowns of two consecutive pages—each structured as a list of Markdown elements (e.g., paragraphs and tables)—the goal is to identify the indexes of elements that should be merged across the pages.
203
+
204
+ For the merging task itself, paragraphs can simply be concatenated. Merging two table fragments, however, is much more challenging. For example, a table spanning multiple pages may repeat the header of the first page on the second page. Another difficult scenario is a cell whose long content wraps across multiple lines, with the first few lines appearing on the previous page and the remaining lines continuing on the next page. We also observe cases where tables with a large number of columns are split vertically and placed on two consecutive pages. More examples of cross-page tables can be found in our [blog](https://ocrflux.pdfparser.io/#/blog) article. To address these issues, we developed an LLM for cross-page table merging: it takes two split table fragments as input and generates a complete, well-structured table as output.
205
+
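+ As a purely illustrative (hypothetical) example, the detection task takes the element lists of two consecutive pages and returns the index pairs to merge:
+
+ ```python
+ # Toy data for illustration only; real elements are full Markdown paragraphs and HTML tables.
+ page_1_elements = ["## 3. Results", "The model achieves ...", "<table>... first half of a table ...</table>"]
+ page_2_elements = ["<table>... second half of the table ...</table>", "## 4. Conclusion"]
+
+ # Expected detection output: (index on page 1, index on page 2) pairs to merge (0-based here).
+ merging_idx_pairs = [(2, 0)]  # the split table spans the page break
+
+ # Paragraph pairs are merged by concatenation; table pairs go through the LLM-based table-merging model.
+ ```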
206
+ We ship two comprehensive benchmarks to help measure the performance of our OCR system in cross-page table/paragraph detection and merging tasks respectively:
207
+
208
+ - [OCRFlux-bench-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-cross): Contains 1000 samples (500 English and 500 Chinese). Each sample contains the Markdown element lists of two consecutive pages, along with the indexes of the elements that need to be merged (manually labeled through multiple rounds of review). If no tables or paragraphs require merging, the index list in the annotation is left empty.
209
+
210
+ - [OCRFlux-pubtabnet-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-cross): Contains 9064 pairs of split table fragments, along with their corresponding ground-truth merged versions.
211
+
212
+ These released benchmarks are likewise NOT included in our training and evaluation data. The main results are as follows:
213
+
214
+ 1. In [OCRFlux-bench-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-bench-cross), we calculated Accuracy, Precision, Recall and F1 score as the metrics. Note that a detection result is counted as correct only when the model accurately judges whether any elements need to be merged across the two pages and outputs the right indexes for them.
215
+
216
+ | Language | Precision ↑ | Recall ↑ | F1 ↑ | Accuracy ↑ |
217
+ |----------|-------------|----------|-------|------------|
218
+ | English | 0.992 | 0.964 | 0.978 | 0.978 |
219
+ | Chinese | 1.000 | 0.988 | 0.994 | 0.994 |
220
+ | Total | 0.996 | 0.976 | 0.986 | 0.986 |
221
+
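+ Scoring is strict: following `eval/eval_element_merge_detect.py`, a prediction for a page pair is counted as correct only if its merge index pairs exactly match the ground truth. A minimal sketch of this check:
+
+ ```python
+ def detection_correct(pred_pairs, gt_pairs) -> bool:
+     # Exact match after sorting; an empty list means "nothing to merge on this page pair".
+     return sorted(pred_pairs) == sorted(gt_pairs)
+ ```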
222
+ 2. In [OCRFlux-pubtabnet-cross](https://huggingface.co/datasets/ChatDOC/OCRFlux-pubtabnet-cross), we calculated the Tree Edit Distance-based Similarity (TEDS) between the generated merged table and the ground-truth merged table as the metric.
223
+
224
+ | Table type | Avg TEDS ↑ |
225
+ |------------|--------------|
226
+ | Simple | 0.965 |
227
+ | Complex | 0.935 |
228
+ | Total | 0.950 |
229
+
230
+ ### Installation
231
+
232
+ Requirements:
233
+ - Recent NVIDIA GPU (tested on RTX 3090, 4090, L40S, A100, H100) with at least 12 GB of GPU RAM
234
+ - 20GB of free disk space
235
+
236
+ You will need to install poppler-utils and additional fonts for rendering PDF images.
237
+
238
+ Install dependencies (Ubuntu/Debian)
239
+ ```bash
240
+ sudo apt-get update
241
+ sudo apt-get install poppler-utils poppler-data ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
242
+ ```
243
+
244
+ Set up a conda environment and install OCRFlux. The requirements for running OCRFlux
245
+ are difficult to install in an existing Python environment, so please create a clean Python environment to install into.
246
+ ```bash
247
+ conda create -n ocrflux python=3.11
248
+ conda activate ocrflux
249
+
250
+ git clone https://github.com/chatdoc-com/OCRFlux.git
251
+ cd OCRFlux
252
+
253
+ pip install -e . --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
254
+ ```
255
+
256
+ ### Local Usage Example
257
+
258
+ For quick testing, try the [web demo](https://5f65ccdc2d4fd2f364.gradio.live). To run locally, a GPU is required, as inference is powered by [vllm](https://github.com/vllm-project/vllm) under the hood.
259
+
260
+ - For a pdf document:
261
+ ```bash
262
+ python -m ocrflux.pipeline ./localworkspace --data test.pdf --model /model_dir/OCRFlux-3B
263
+ ```
264
+
265
+ - For an image:
266
+ ```bash
267
+ python -m ocrflux.pipeline ./localworkspace --data test_page.png --model /model_dir/OCRFlux-3B
268
+ ```
269
+
270
+ - For a directory of pdf or images:
271
+ ```bash
272
+ python -m ocrflux.pipeline ./localworkspace --data test_pdf_dir/* --model /model_dir/OCRFlux-3B
273
+ ```
274
+ You can set `--skip_cross_page_merge` to skip cross-page merging and speed up parsing; the pipeline will then simply concatenate the parsing results of each page to generate the final Markdown of the document.
275
+
276
+ Results will be stored as JSONL files in the `./localworkspace/results` directory.
277
+
278
+ Each line in these JSONL files is a JSON object with the following fields:
279
+
280
+ ```
281
+ {
282
+ "orig_path": str, # the path to the raw pdf or image file
283
+ "num_pages": int, # the number of pages in the pdf file
284
+ "document_text": str, # the Markdown text of the converted pdf or image file
285
+ "page_texts": dict, # the Markdown texts of each page in the pdf file, the key is the page index and the value is the Markdown text of the page
286
+ "fallback_pages": [int], # the page indexes that are not converted successfully
287
+ }
288
+ ```
289
+
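+ As a small illustration (this loop is a sketch, not part of the CLI), the result files can be post-processed directly from Python:
+
+ ```python
+ import glob
+ import json
+
+ for path in glob.glob("./localworkspace/results/*.jsonl"):
+     with open(path) as f:
+         for line in f:
+             record = json.loads(line)
+             # document_text holds the full Markdown; fallback_pages lists pages that failed to convert.
+             print(record["orig_path"], record["num_pages"], record["fallback_pages"])
+ ```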
290
+ ### API for directly calling OCRFlux (New)
291
+ You can use the inference API to call OCRFlux directly in your own code, without launching an online vllm server, as follows:
292
+
293
+ ```python
294
+ from vllm import LLM
295
+ from ocrflux.inference import parse
296
+
297
+ file_path = 'test.pdf'
298
+ # file_path = 'test.png'
299
+ llm = LLM(model="model_dir/OCRFlux-3B", gpu_memory_utilization=0.8, max_model_len=8192)
300
+ result = parse(llm,file_path)
301
+ if result is not None:
302
+ document_markdown = result['document_text']
303
+ print(document_markdown)
304
+ with open('test.md','w') as f:
305
+ f.write(document_markdown)
306
+ else:
307
+ print("Parse failed.")
308
+ ```
309
+ If parsing fails or there are fallback pages in the result, you can try setting the `max_page_retries` argument of the `parse` function to a positive integer to get a better result, at the cost of longer inference time.
310
+
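+ For example (a sketch reusing the `llm` object from the snippet above):
+
+ ```python
+ # Retry pages that fail to parse up to 3 times before marking them as fallback pages.
+ result = parse(llm, file_path, max_page_retries=3)
+ ```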
311
+ ### Docker Usage
312
+
313
+ Requirements:
314
+
315
+ - Docker with GPU support [(NVIDIA Toolkit)](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
316
+ - Pre-downloaded model: [OCRFlux-3B](https://huggingface.co/ChatDOC/OCRFlux-3B)
317
+
318
+ To use OCRFlux in a docker container, you can use the following example command:
319
+
320
+ ```bash
321
+ docker run -it --gpus all \
322
+ -v /path/to/localworkspace:/localworkspace \
323
+ -v /path/to/test_pdf_dir:/test_pdf_dir/ \
324
+ -v /path/to/OCRFlux-3B:/OCRFlux-3B \
325
+ chatdoc/ocrflux:latest /localworkspace --data /test_pdf_dir/* --model /OCRFlux-3B/
326
+ ```
327
+
328
+ #### Viewing Results
329
+ Generate the final Markdown files by running the following command. The generated Markdown files will be placed in the `./localworkspace/markdowns/DOCUMENT_NAME` directory.
330
+
331
+ ```bash
332
+ python -m ocrflux.jsonl_to_markdown ./localworkspace
333
+ ```
334
+
335
+ ### Full documentation for the pipeline
336
+
337
+ ```bash
338
+ python -m ocrflux.pipeline --help
339
+ usage: pipeline.py [-h] [--task {pdf2markdown,merge_pages,merge_tables}] [--data [DATA ...]] [--pages_per_group PAGES_PER_GROUP] [--max_page_retries MAX_PAGE_RETRIES]
340
+ [--max_page_error_rate MAX_PAGE_ERROR_RATE] [--workers WORKERS] [--model MODEL] [--model_max_context MODEL_MAX_CONTEXT] [--model_chat_template MODEL_CHAT_TEMPLATE]
341
+ [--target_longest_image_dim TARGET_LONGEST_IMAGE_DIM] [--skip_cross_page_merge] [--port PORT]
342
+ workspace
343
+
344
+ Manager for running millions of PDFs through a batch inference pipeline
345
+
346
+ positional arguments:
347
+ workspace The filesystem path where work will be stored, can be a local folder
348
+
349
+ options:
350
+ -h, --help show this help message and exit
351
+ --data [DATA ...] List of paths to files to process
352
+ --pages_per_group PAGES_PER_GROUP
353
+ Aiming for this many pdf pages per work item group
354
+ --max_page_retries MAX_PAGE_RETRIES
355
+ Max number of times we will retry rendering a page
356
+ --max_page_error_rate MAX_PAGE_ERROR_RATE
357
+ Rate of allowable failed pages in a document, 1/250 by default
358
+ --workers WORKERS Number of workers to run at a time
359
+ --model MODEL The path to the model
360
+ --model_max_context MODEL_MAX_CONTEXT
361
+ Maximum context length that the model was fine tuned under
362
+ --model_chat_template MODEL_CHAT_TEMPLATE
363
+ Chat template to pass to vllm server
364
+ --target_longest_image_dim TARGET_LONGEST_IMAGE_DIM
365
+ Dimension on longest side to use for rendering the pdf pages
366
+ --skip_cross_page_merge
367
+ Whether to skip cross-page merging
368
+ --port PORT Port to use for the VLLM server
369
+ ```
370
+
371
+ ## Code overview
372
+
373
+ There are some nice reusable pieces of the code that may be useful for your own projects:
374
+ - Processing millions of PDFs through our released model using VLLM - [pipeline.py](https://github.com/chatdoc-com/OCRFlux/blob/main/ocrflux/pipeline.py)
375
+ - Generating final Markdowns from jsonl files - [jsonl_to_markdown.py](https://github.com/chatdoc-com/OCRFlux/blob/main/ocrflux/jsonl_to_markdown.py)
376
+ - Evaluating the model on the single-page parsing task - [eval_page_to_markdown.py](https://github.com/chatdoc-com/OCRFlux/blob/main/eval/eval_page_to_markdown.py)
377
+ - Evaluating the model on the table parsing task - [eval_table_to_html.py](https://github.com/chatdoc-com/OCRFlux/blob/main/eval/eval_table_to_html.py)
378
+ - Evaluating the model on the paragraphs/tables merging detection task - [eval_element_merge_detect.py](https://github.com/chatdoc-com/OCRFlux/blob/main/eval/eval_element_merge_detect.py)
379
+ - Evaluating the model on the table merging task - [eval_html_table_merge.py](https://github.com/chatdoc-com/OCRFlux/blob/main/eval/eval_html_table_merge.py)
380
+
381
+
382
+ ## Team
383
+
384
+ <!-- start team -->
385
+
386
+ **OCRFlux** is developed and maintained by the ChatDOC team, backed by [ChatDOC](https://chatdoc.com/).
387
+
388
+ <!-- end team -->
389
+
390
+ ## License
391
+
392
+ <!-- start license -->
393
+
394
+ **OCRFlux** is licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0).
395
+ A full copy of the license can be found [on GitHub](https://github.com/allenai/OCRFlux/blob/main/LICENSE).
396
+
397
+ <!-- end license -->
eval/eval.sh ADDED
@@ -0,0 +1,27 @@
1
+ #!/bin/bash
2
+
3
+ # page_to_markdown task
4
+ python -m ocrflux.pipeline ./eval_page_to_markdown_result --task pdf2markdown --data /data/OCRFlux-bench-single/pdfs/*.pdf --model /data/OCRFlux-7B
5
+
6
+ python -m eval.eval_page_to_markdown ./eval_page_to_markdown_result --gt_file /data/OCRFlux-bench-single/data.jsonl
7
+
8
+ # element_merge_detect task
9
+ python -m eval.gen_element_merge_detect_data /data/OCRFlux-bench-cross
10
+
11
+ python -m ocrflux.pipeline ./eval_element_merge_detect_result --task merge_pages --data /data/OCRFlux-bench-cross/jsons/*.json --model /data/OCRFlux-7B
12
+
13
+ python -m eval.eval_element_merge_detect ./eval_element_merge_detect_result --gt_file /data/OCRFlux-bench-cross/data.jsonl
14
+
15
+ # table_to_html task
16
+ python -m ocrflux.pipeline ./eval_table_to_html_result --task pdf2markdown --data /data/OCRFlux-pubtabnet-single/images/*.png --model /data/OCRFlux-7B
17
+
18
+ python -m eval.eval_table_to_html ./eval_table_to_html_result --gt_file /data/OCRFlux-pubtabnet-single/data.jsonl
19
+
20
+ # html_table_merge task
21
+ python -m eval.gen_html_table_merge_data /data/OCRFlux-pubtabnet-cross
22
+
23
+ python -m ocrflux.pipeline ./eval_html_table_merge_result --task merge_tables --data /data/OCRFlux-pubtabnet-cross/jsons/*.json --model /data/OCRFlux-7B
24
+
25
+ python -m eval.eval_html_table_merge ./eval_html_table_merge_result --gt_file /data/OCRFlux-pubtabnet-cross/data.jsonl
26
+
27
+
eval/eval_element_merge_detect.py ADDED
@@ -0,0 +1,137 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import nltk
5
+ from tqdm import tqdm
6
+ from eval.parallel import parallel_process
7
+
8
+
9
+ def evaluate(pred, gt):
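+ # Exact-match scoring: returns 1 only if the sorted predicted merge index pairs equal the sorted ground-truth pairs, else 0.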
10
+ pred = sorted(pred, key=lambda x: (x[0], x[1]))
11
+ gt = sorted(gt, key=lambda x: (x[0], x[1]))
12
+ if pred == gt:
13
+ return 1
14
+ else:
15
+ return 0
16
+
17
+ def main():
18
+ parser = argparse.ArgumentParser(description="Evaluate element_merge_detect task")
19
+ parser.add_argument(
20
+ "workspace",
21
+ help="The filesystem path where work will be stored, can be a local folder",
22
+ )
23
+ parser.add_argument(
24
+ "--gt_file",
25
+ help="Ground truth file",
26
+ )
27
+ parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
28
+ args = parser.parse_args()
29
+
30
+ pred_data = {}
31
+ root_dir = os.path.join(args.workspace, "results")
32
+ for jsonl_file in os.listdir(root_dir):
33
+ if jsonl_file.endswith(".jsonl"):
34
+ with open(os.path.join(root_dir, jsonl_file), "r") as f:
35
+ for line in f:
36
+ data = json.loads(line)
37
+ pred_data[os.path.basename(data['orig_path'])] = data['merge_pairs']
38
+
39
+ filename_list_en = []
40
+ filename_list_zh = []
41
+ gt_data = {}
42
+ with open(args.gt_file, "r") as f:
43
+ for line in f:
44
+ data = json.loads(line)
45
+ pdf_name_1 = data['pdf_name_1'].split(".")[0]
46
+ pdf_name_2 = data['pdf_name_2'].split(".")[0]
47
+
48
+ pdf_name,page_1 = pdf_name_1.split('_')
49
+ pdf_name,page_2 = pdf_name_2.split('_')
50
+
51
+ json_name = pdf_name + '_' + page_1 + '_' + page_2 + '.json'
52
+ gt_data[json_name] = data['merging_idx_pairs']
53
+
54
+ if data['language'] == 'en':
55
+ filename_list_en.append(json_name)
56
+ else:
57
+ filename_list_zh.append(json_name)
58
+
59
+ keys = list(gt_data.keys())
60
+ if args.n_jobs == 1:
61
+ scores = [evaluate(pred_data.get(filename, []), gt_data.get(filename, [])) for filename in tqdm(keys)]
62
+ else:
63
+ inputs = [{'pred': pred_data.get(filename, []), 'gt': gt_data.get(filename, [])} for filename in keys]
64
+ scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)
65
+
66
+ tp_en = 0
67
+ tn_en = 0
68
+ fp_en = 0
69
+ fn_en = 0
70
+ tp_zh = 0
71
+ tn_zh = 0
72
+ fp_zh = 0
73
+ fn_zh = 0
74
+ score_en = 0
75
+ score_zh = 0
76
+ num_en = 0
77
+ num_zh = 0
78
+ for filename, score in zip(keys, scores):
79
+ print(filename)
80
+ print(score)
81
+ print()
82
+ pred_label = pred_data[filename]
83
+ if filename in filename_list_en:
84
+ if pred_label == []:
85
+ if score == 1:
86
+ tn_en += 1
87
+ else:
88
+ fn_en += 1
89
+ else:
90
+ if score == 1:
91
+ tp_en += 1
92
+ else:
93
+ fp_en += 1
94
+ score_en += score
95
+ num_en += 1
96
+
97
+ elif filename in filename_list_zh:
98
+ if pred_label == []:
99
+ if score == 1:
100
+ tn_zh += 1
101
+ else:
102
+ fn_zh += 1
103
+ else:
104
+ if score == 1:
105
+ tp_zh += 1
106
+ else:
107
+ fp_zh += 1
108
+ score_zh += score
109
+ num_zh += 1
110
+
111
+ precision_en = tp_en / (tp_en + fp_en)
112
+ recall_en = tp_en / (tp_en + fn_en)
113
+ f1_en = 2*precision_en*recall_en / (precision_en+recall_en)
114
+ acc_en = score_en / num_en
115
+
116
+ precision_zh = tp_zh / (tp_zh + fp_zh)
117
+ recall_zh = tp_zh / (tp_zh + fn_zh)
118
+ f1_zh = 2*precision_zh*recall_zh / (precision_zh+recall_zh)
119
+ acc_zh = score_zh / num_zh
120
+
121
+ tp = tp_en + tp_zh
122
+ fp = fp_en + fp_zh
123
+ fn = fn_en + fn_zh
124
+ score = score_en + score_zh
125
+ num = num_en + num_zh
126
+
127
+ precision = tp / (tp + fp)
128
+ recall = tp / (tp + fn)
129
+ f1 = 2*precision*recall / (precision+recall)
130
+ acc = score / num
131
+
132
+ print(f"EN: {precision_en} / {recall_en} / {f1_en} / {acc_en}")
133
+ print(f"ZH: {precision_zh} / {recall_zh} / {f1_zh} / {acc_zh}")
134
+ print(f"ALL: {precision} / {recall} / {f1} / {acc}")
135
+
136
+ if __name__ == "__main__":
137
+ main()
eval/eval_html_table_merge.py ADDED
@@ -0,0 +1,208 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import distance
5
+ from apted import APTED, Config
6
+ from apted.helpers import Tree
7
+ from lxml import etree, html
8
+ from collections import deque
9
+ from tqdm import tqdm
10
+ from eval.parallel import parallel_process
11
+
12
+
13
+ class TableTree(Tree):
14
+ def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
15
+ self.tag = tag
16
+ self.colspan = colspan
17
+ self.rowspan = rowspan
18
+ self.content = content
19
+ self.children = list(children)
20
+
21
+ def bracket(self):
22
+ """Show tree using brackets notation"""
23
+ if self.tag == 'td':
24
+ result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
25
+ (self.tag, self.colspan, self.rowspan, self.content)
26
+ else:
27
+ result = '"tag": %s' % self.tag
28
+ for child in self.children:
29
+ result += child.bracket()
30
+ return "{{{}}}".format(result)
31
+
32
+
33
+ class CustomConfig(Config):
34
+ @staticmethod
35
+ def maximum(*sequences):
36
+ """Get maximum possible value
37
+ """
38
+ return max(map(len, sequences))
39
+
40
+ def normalized_distance(self, *sequences):
41
+ """Get distance from 0 to 1
42
+ """
43
+ return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
44
+
45
+ def rename(self, node1, node2):
46
+ """Compares attributes of trees"""
47
+ if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
48
+ return 1.
49
+ if node1.tag == 'td':
50
+ if node1.content or node2.content:
51
+ return self.normalized_distance(node1.content, node2.content)
52
+ return 0.
53
+
54
+
55
+ class TEDS(object):
56
+ ''' Tree Edit Distance based Similarity
57
+ '''
58
+ def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
59
+ assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be a positive integer'
60
+ self.structure_only = structure_only
61
+ self.n_jobs = n_jobs
62
+ self.ignore_nodes = ignore_nodes
63
+ self.__tokens__ = []
64
+
65
+ def tokenize(self, node):
66
+ ''' Tokenizes table cells
67
+ '''
68
+ self.__tokens__.append('<%s>' % node.tag)
69
+ if node.text is not None:
70
+ self.__tokens__ += list(node.text)
71
+ for n in node.getchildren():
72
+ self.tokenize(n)
73
+ if node.tag != 'unk':
74
+ self.__tokens__.append('</%s>' % node.tag)
75
+ if node.tag != 'td' and node.tail is not None:
76
+ self.__tokens__ += list(node.tail)
77
+
78
+ def load_html_tree(self, node, parent=None):
79
+ ''' Converts HTML tree to the format required by apted
80
+ '''
81
+ global __tokens__
82
+ if node.tag == 'td':
83
+ if self.structure_only:
84
+ cell = []
85
+ else:
86
+ self.__tokens__ = []
87
+ self.tokenize(node)
88
+ cell = self.__tokens__[1:-1].copy()
89
+ new_node = TableTree(node.tag,
90
+ int(node.attrib.get('colspan', '1')),
91
+ int(node.attrib.get('rowspan', '1')),
92
+ cell, *deque())
93
+ else:
94
+ new_node = TableTree(node.tag, None, None, None, *deque())
95
+ if parent is not None:
96
+ parent.children.append(new_node)
97
+ if node.tag != 'td':
98
+ for n in node.getchildren():
99
+ self.load_html_tree(n, new_node)
100
+ if parent is None:
101
+ return new_node
102
+
103
+ def evaluate(self, pred, true):
104
+ ''' Computes TEDS score between the prediction and the ground truth of a
105
+ given sample
106
+ '''
107
+ if (not pred) or (not true):
108
+ return 0.0
109
+ pred = "<html>" + pred + "</html>"
110
+ true = "<html>" + true + "</html>"
111
+ parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
112
+ pred = html.fromstring(pred, parser=parser)
113
+ true = html.fromstring(true, parser=parser)
114
+ if pred.xpath('body/table') and true.xpath('body/table'):
115
+ pred = pred.xpath('body/table')[0]
116
+ true = true.xpath('body/table')[0]
117
+ if self.ignore_nodes:
118
+ etree.strip_tags(pred, *self.ignore_nodes)
119
+ etree.strip_tags(true, *self.ignore_nodes)
120
+ n_nodes_pred = len(pred.xpath(".//*"))
121
+ n_nodes_true = len(true.xpath(".//*"))
122
+ n_nodes = max(n_nodes_pred, n_nodes_true)
123
+ tree_pred = self.load_html_tree(pred)
124
+ tree_true = self.load_html_tree(true)
125
+ distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
126
+ return 1.0 - (float(distance) / n_nodes)
127
+ else:
128
+ return 0.0
129
+
130
+ def batch_evaluate(self, pred_json, true_json):
131
+ ''' Computes TEDS score between the prediction and the ground truth of
132
+ a batch of samples
133
+ @params pred_json: {'FILENAME': 'HTML CODE', ...}
134
+ @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
135
+ @output: {'FILENAME': 'TEDS SCORE', ...}
136
+ '''
137
+ samples = true_json.keys()
138
+ if self.n_jobs == 1:
139
+ scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
140
+ else:
141
+ inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
142
+ scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
143
+ total_score_simple = 0
144
+ num_simple = 0
145
+ total_score_complex = 0
146
+ num_complex = 0
147
+ total_score = 0
148
+ num_total = 0
149
+ for filename,score in zip(samples, scores):
150
+ print(filename)
151
+ print(score)
152
+ print('')
153
+ if true_json[filename]['type'] == 'simple':
154
+ total_score_simple += score
155
+ num_simple += 1
156
+ elif true_json[filename]['type'] == 'complex':
157
+ total_score_complex += score
158
+ num_complex += 1
159
+ else:
160
+ raise ValueError('Unknown type: %s' % true_json[filename]['type'])
161
+ total_score += score
162
+ num_total += 1
163
+ if num_simple > 0:
164
+ avg_score_simple = total_score_simple / num_simple
165
+ else:
166
+ avg_score_simple = 0
167
+ if num_complex > 0:
168
+ avg_score_complex = total_score_complex / num_complex
169
+ else:
170
+ avg_score_complex = 0
171
+ avg_score = total_score / num_total
172
+ print({'simple': (num_simple,avg_score_simple), 'complex': (num_complex,avg_score_complex), 'total': (num_total,avg_score)})
173
+
174
+ def main():
175
+ parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
176
+ parser.add_argument(
177
+ "workspace",
178
+ help="The filesystem path where work will be stored, can be a local folder",
179
+ )
180
+ parser.add_argument(
181
+ "--gt_file",
182
+ help="Ground truth file",
183
+ )
184
+ parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
185
+ args = parser.parse_args()
186
+
187
+ pred_data = {}
188
+ root_dir = os.path.join(args.workspace, "results")
189
+ for jsonl_file in os.listdir(root_dir):
190
+ if jsonl_file.endswith(".jsonl"):
191
+ with open(os.path.join(root_dir, jsonl_file), "r") as f:
192
+ for line in f:
193
+ data = json.loads(line)
194
+ key = os.path.basename(data['orig_path']).split('.')[0]
195
+ pred_data[key] = data['merged_tables']
196
+
197
+ gt_data = {}
198
+ with open(args.gt_file, "r") as f:
199
+ for line in f:
200
+ data = json.loads(line)
201
+ key = data['image_name'].split('.')[0]
202
+ gt_data[key] = {'html':data['gt_table'], 'type':data['type']}
203
+
204
+ teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
205
+ teds.batch_evaluate(pred_data, gt_data)
206
+
207
+ if __name__ == "__main__":
208
+ main()
eval/eval_page_to_markdown.py ADDED
@@ -0,0 +1,76 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import nltk
5
+ from tqdm import tqdm
6
+ from eval.parallel import parallel_process
7
+
8
+
9
+ def evaluate(pred, gt):
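+ # Edit Distance Similarity (EDS): 1 minus the normalized character-level edit distance between prediction and ground truth.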
10
+ edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
11
+ return 1.0 - edit_dist
12
+
13
+
14
+ def main():
15
+ parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
16
+ parser.add_argument(
17
+ "workspace",
18
+ help="The filesystem path where work will be stored, can be a local folder",
19
+ )
20
+ parser.add_argument(
21
+ "--gt_file",
22
+ help="Ground truth file",
23
+ )
24
+ parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
25
+ args = parser.parse_args()
26
+
27
+ pred_data = {}
28
+ root_dir = os.path.join(args.workspace, "results")
29
+ for jsonl_file in os.listdir(root_dir):
30
+ if jsonl_file.endswith(".jsonl"):
31
+ with open(os.path.join(root_dir, jsonl_file), "r") as f:
32
+ for line in f:
33
+ data = json.loads(line)
34
+ pred_data[os.path.basename(data['orig_path'])] = data['document_text']
35
+
36
+ filename_list_en = []
37
+ filename_list_zh = []
38
+ gt_data = {}
39
+ with open(args.gt_file, "r") as f:
40
+ for line in f:
41
+ data = json.loads(line)
42
+ markdown = data['markdown']
43
+ pdf_name = data['pdf_name']
44
+ gt_data[pdf_name] = markdown
45
+ if data['language'] == 'en':
46
+ filename_list_en.append(pdf_name)
47
+ else:
48
+ filename_list_zh.append(pdf_name)
49
+
50
+ keys = list(gt_data.keys())
51
+ if args.n_jobs == 1:
52
+ scores = [evaluate(pred_data.get(filename, ''), gt_data.get(filename, '')) for filename in tqdm(keys)]
53
+ else:
54
+ inputs = [{'pred': pred_data.get(filename, ''), 'gt': gt_data.get(filename, '')} for filename in keys]
55
+ scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)
56
+
57
+ total_score_en = 0
58
+ total_num_en = 0
59
+ total_score_zh = 0
60
+ total_num_zh = 0
61
+ for filename, score in zip(keys, scores):
62
+ print(filename)
63
+ print(score)
64
+ print()
65
+ if filename in filename_list_en:
66
+ total_score_en += score
67
+ total_num_en += 1
68
+ elif filename in filename_list_zh:
69
+ total_score_zh += score
70
+ total_num_zh += 1
71
+ print(f"English: {total_score_en / total_num_en}")
72
+ print(f"Chinese: {total_score_zh / total_num_zh}")
73
+ print(f"Total: {sum(scores) / len(scores)}")
74
+
75
+ if __name__ == "__main__":
76
+ main()
eval/eval_page_to_markdown_nanonets.py ADDED
@@ -0,0 +1,160 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import argparse
5
+ import nltk
6
+ import markdown2
7
+ from bs4 import BeautifulSoup
8
+ from tqdm import tqdm
9
+ from eval.parallel import parallel_process
10
+
11
+ def turn_header_to_h1(line):
12
+ # Check whether the line is a heading that starts with one or more '#'
13
+ if line.lstrip().startswith('#'):
14
+ # Strip the leading '#' characters and the following spaces, then prepend a single '# '
15
+ new_line = "# " + line.lstrip().lstrip('#').lstrip()
16
+ return new_line
17
+ else:
18
+ return line
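+ # e.g. "### Results" becomes "# Results": predicted heading levels are collapsed to h1
+ # before the edit-distance comparison.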
19
+
20
+ def replace_single_dollar(markdown_text):
21
+ pattern = r'\$(.*?)\$'
22
+ def replace_with_brackets(match):
23
+ formula_content = match.group(1) # the formula content captured by the match
24
+ return f'\\({formula_content}\\)'
25
+
26
+ replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
27
+
28
+ return replaced_text
29
+
30
+
31
+ def replace_double_dollar(markdown_text):
32
+ pattern = r'\$\$(.*?)\$\$'
33
+ def replace_with_brackets(match):
34
+ formula_content = match.group(1)
35
+ return f'\\[{formula_content}\\]'
36
+ replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
37
+
38
+ return replaced_text
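+ # Together these normalize formula delimiters, e.g. "$$E=mc^2$$" -> "\[E=mc^2\]" and
+ # "$x$" -> "\(x\)"; double-dollar blocks are replaced first so the single-dollar pass
+ # does not split them.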
39
+
40
+ def simplify_html_table(html_table):
41
+ # Parse the HTML with BeautifulSoup
42
+ soup = BeautifulSoup(html_table, 'html.parser')
43
+
44
+ # Locate the <table> tag
45
+ table = soup.find('table')
46
+ if not table:
47
+ raise ValueError("输入的 HTML 不包含有效的 <table> 标签")
48
+
49
+ # Create a new, empty <table> tag
50
+ new_table = BeautifulSoup('<table></table>', 'html.parser').table
51
+
52
+ # Extract all rows (including those inside <thead> and <tbody>)
53
+ rows = table.find_all(['tr'], recursive=True)
54
+
55
+ for row in rows:
56
+ # Create a new <tr> tag
57
+ new_row = soup.new_tag('tr')
58
+
59
+ # Process the cells in this row
60
+ cells = row.find_all(['th', 'td'])
61
+ for cell in cells:
62
+ # Replace <th> with <td>
63
+ new_cell = soup.new_tag('td')
64
+ if cell.has_attr('rowspan'):
65
+ new_cell['rowspan'] = cell['rowspan']
66
+ if cell.has_attr('colspan'):
67
+ new_cell['colspan'] = cell['colspan']
68
+ new_cell.string = cell.get_text(strip=True) # keep the cell text
69
+ new_row.append(new_cell)
70
+
71
+ # Append the new row to the new table
72
+ new_table.append(new_row)
73
+
74
+ # Return the simplified table HTML
75
+ return str(new_table)
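+ # Example: '<table><thead><tr><th colspan="2">A</th></tr></thead></table>' is reduced to
+ # '<table><tr><td colspan="2">A</td></tr></table>', keeping only rows, cells, spans and text.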
76
+
77
+ def evaluate(pred, gt):
78
+ edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
79
+ return 1.0 - edit_dist
80
+
81
+
82
+ def main():
83
+ parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
84
+ parser.add_argument(
85
+ "workspace",
86
+ help="The filesystem path where work will be stored, can be a local folder",
87
+ )
88
+ parser.add_argument(
89
+ "--gt_file",
90
+ help="Ground truth file",
91
+ )
92
+ parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
93
+ args = parser.parse_args()
94
+
95
+ pred_data = {}
96
+ for file in os.listdir(args.workspace):
97
+ file_path = os.path.join(args.workspace, file)
98
+ pdf_name = file.split('.')[0] + ".pdf"
99
+ with open(file_path, "r") as f:
100
+ document_text = f.read()
101
+ document_text = replace_single_dollar(replace_double_dollar(document_text))
102
+ markdown_text_list = document_text.split("\n\n")
103
+ new_markdown_text_list = []
104
+ for text in markdown_text_list:
105
+ text = text.strip()
106
+ if (text.startswith("<watermark>") and text.endswith("</watermark>")) or (text.startswith("<img>") and text.endswith("</img>")) or (text.startswith("<page_number>") and text.endswith("</page_number>")) or (text.startswith("<signature>") and text.endswith("</signature>")):
107
+ continue
108
+ else:
109
+ html_text = str(markdown2.markdown(text,extras=["tables"]))
110
+ html_text = html_text.strip()
111
+ if html_text.startswith("<table>") and html_text.endswith("</table>"):
112
+ html_table = simplify_html_table(html_text)
113
+ new_markdown_text_list.append(html_table)
114
+ else:
115
+ text = turn_header_to_h1(text)
116
+ new_markdown_text_list.append(text)
117
+
118
+ pred_data[os.path.basename(pdf_name)] = "\n\n".join(new_markdown_text_list)
119
+
120
+ filename_list_en = []
121
+ filename_list_zh = []
122
+ gt_data = {}
123
+ with open(args.gt_file, "r") as f:
124
+ for line in f:
125
+ data = json.loads(line)
126
+ markdown = data['markdown']
127
+ pdf_name = data['pdf_name']
128
+ gt_data[pdf_name] = markdown
129
+ if data['language'] == 'en':
130
+ filename_list_en.append(pdf_name)
131
+ else:
132
+ filename_list_zh.append(pdf_name)
133
+
134
+ keys = list(gt_data.keys())
135
+ if args.n_jobs == 1:
136
+ scores = [evaluate(pred_data.get(filename, ''), gt_data.get(filename, '')) for filename in tqdm(keys)]
137
+ else:
138
+ inputs = [{'pred': pred_data.get(filename, ''), 'gt': gt_data.get(filename, '')} for filename in keys]
139
+ scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)
140
+
141
+ total_score_en = 0
142
+ total_num_en = 0
143
+ total_score_zh = 0
144
+ total_num_zh = 0
145
+ for filename, score in zip(keys, scores):
146
+ if filename in filename_list_en:
147
+ print(filename)
148
+ print(score)
149
+ print()
150
+ total_score_en += score
151
+ total_num_en += 1
152
+ elif filename in filename_list_zh:
153
+ total_score_zh += score
154
+ total_num_zh += 1
155
+ print(f"English: {total_score_en / total_num_en}")
156
+ print(f"Chinese: {total_score_zh / total_num_zh}")
157
+ print(f"Total: {sum(scores) / len(scores)}")
158
+
159
+ if __name__ == "__main__":
160
+ main()
eval/eval_page_to_markdown_olmocr.py ADDED
@@ -0,0 +1,157 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import argparse
5
+ import nltk
6
+ import markdown2
7
+ from bs4 import BeautifulSoup
8
+ from tqdm import tqdm
9
+ from eval.parallel import parallel_process
10
+
11
+ def turn_header_to_h1(line):
12
+ # Check whether the line is a heading that starts with one or more '#'
13
+ if line.lstrip().startswith('#'):
14
+ # Strip the leading '#' characters and the following spaces, then prepend a single '# '
15
+ new_line = "# " + line.lstrip().lstrip('#').lstrip()
16
+ return new_line
17
+ else:
18
+ return line
19
+
20
+ def replace_single_dollar(markdown_text):
21
+ pattern = r'\$(.*?)\$'
22
+ def replace_with_brackets(match):
23
+ formula_content = match.group(1) # the formula content captured by the match
24
+ return f'\\({formula_content}\\)'
25
+
26
+ replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
27
+
28
+ return replaced_text
29
+
30
+
31
+ def replace_double_dollar(markdown_text):
32
+ pattern = r'\$\$(.*?)\$\$'
33
+ def replace_with_brackets(match):
34
+ formula_content = match.group(1)
35
+ return f'\\[{formula_content}\\]'
36
+ replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
37
+
38
+ return replaced_text
39
+
40
+ def simplify_html_table(html_table):
41
+ # Parse the HTML with BeautifulSoup
42
+ soup = BeautifulSoup(html_table, 'html.parser')
43
+
44
+ # Locate the <table> tag
45
+ table = soup.find('table')
46
+ if not table:
47
+ raise ValueError("输入的 HTML 不包含有效的 <table> 标签")
48
+
49
+ # Create a new, empty <table> tag
50
+ new_table = BeautifulSoup('<table></table>', 'html.parser').table
51
+
52
+ # Extract all rows (including those inside <thead> and <tbody>)
53
+ rows = table.find_all(['tr'], recursive=True)
54
+
55
+ for row in rows:
56
+ # Create a new <tr> tag
57
+ new_row = soup.new_tag('tr')
58
+
59
+ # Process the cells in this row
60
+ cells = row.find_all(['th', 'td'])
61
+ for cell in cells:
62
+ # Replace <th> with <td>
63
+ new_cell = soup.new_tag('td')
64
+ new_cell.string = cell.get_text(strip=True) # keep the cell text
65
+ new_row.append(new_cell)
66
+
67
+ # Append the new row to the new table
68
+ new_table.append(new_row)
69
+
70
+ # Return the simplified table HTML
71
+ return str(new_table)
72
+
73
+ def evaluate(pred, gt):
74
+ edit_dist = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
75
+ return 1.0 - edit_dist
76
+
77
+
78
+ def main():
79
+ parser = argparse.ArgumentParser(description="Evaluate page_to_markdown task")
80
+ parser.add_argument(
81
+ "workspace",
82
+ help="The filesystem path where work will be stored, can be a local folder",
83
+ )
84
+ parser.add_argument(
85
+ "--gt_file",
86
+ help="Ground truth file",
87
+ )
88
+ parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
89
+ args = parser.parse_args()
90
+
91
+ pred_data = {}
92
+ root_dir = os.path.join(args.workspace, "results")
93
+ for jsonl_file in os.listdir(root_dir):
94
+ if jsonl_file.endswith(".jsonl"):
95
+ with open(os.path.join(root_dir, jsonl_file), "r") as f:
96
+ for line in f:
97
+ data = json.loads(line)
98
+ pdf_path = data['metadata']['Source-File']
99
+ document_text = data['text']
100
+ document_text = replace_single_dollar(replace_double_dollar(document_text))
101
+
102
+ markdown_text_list = document_text.split("\n\n")
103
+
104
+ new_markdown_text_list = []
105
+ for text in markdown_text_list:
106
+ html_text = str(markdown2.markdown(text,extras=["tables"]))
107
+ html_text = html_text.strip()
108
+ if html_text.startswith("<table>") and html_text.endswith("</table>"):
109
+ html_table = simplify_html_table(html_text)
110
+ new_markdown_text_list.append(html_table)
111
+ else:
112
+ text = turn_header_to_h1(text)
113
+ new_markdown_text_list.append(text)
114
+
115
+ pred_data[os.path.basename(pdf_path)] = "\n\n".join(new_markdown_text_list)
116
+
117
+ filename_list_en = []
118
+ filename_list_zh = []
119
+ gt_data = {}
120
+ with open(args.gt_file, "r") as f:
121
+ for line in f:
122
+ data = json.loads(line)
123
+ markdown = data['markdown']
124
+ pdf_name = data['pdf_name']
125
+ gt_data[pdf_name] = markdown
126
+ if data['language'] == 'en':
127
+ filename_list_en.append(pdf_name)
128
+ else:
129
+ filename_list_zh.append(pdf_name)
130
+
131
+ keys = list(gt_data.keys())
132
+ if args.n_jobs == 1:
133
+ scores = [evaluate(pred_data.get(filename, ''), gt_data.get(filename, '')) for filename in tqdm(keys)]
134
+ else:
135
+ inputs = [{'pred': pred_data.get(filename, ''), 'gt': gt_data.get(filename, '')} for filename in keys]
136
+ scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)
137
+
138
+ total_score_en = 0
139
+ total_num_en = 0
140
+ total_score_zh = 0
141
+ total_num_zh = 0
142
+ for filename, score in zip(keys, scores):
143
+ if filename in filename_list_en:
144
+ print(filename)
145
+ print(score)
146
+ print()
147
+ total_score_en += score
148
+ total_num_en += 1
149
+ elif filename in filename_list_zh:
150
+ total_score_zh += score
151
+ total_num_zh += 1
152
+ print(f"English: {total_score_en / total_num_en}")
153
+ print(f"Chinese: {total_score_zh / total_num_zh}")
154
+ print(f"Total: {sum(scores) / len(scores)}")
155
+
156
+ if __name__ == "__main__":
157
+ main()
eval/eval_table_to_html.py ADDED
@@ -0,0 +1,206 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import distance
5
+ from apted import APTED, Config
6
+ from apted.helpers import Tree
7
+ from lxml import etree, html
8
+ from collections import deque
9
+ from tqdm import tqdm
10
+ from eval.parallel import parallel_process
11
+
12
+
13
+ class TableTree(Tree):
14
+ def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
15
+ self.tag = tag
16
+ self.colspan = colspan
17
+ self.rowspan = rowspan
18
+ self.content = content
19
+ self.children = list(children)
20
+
21
+ def bracket(self):
22
+ """Show tree using brackets notation"""
23
+ if self.tag == 'td':
24
+ result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
25
+ (self.tag, self.colspan, self.rowspan, self.content)
26
+ else:
27
+ result = '"tag": %s' % self.tag
28
+ for child in self.children:
29
+ result += child.bracket()
30
+ return "{{{}}}".format(result)
31
+
32
+
33
+ class CustomConfig(Config):
34
+ @staticmethod
35
+ def maximum(*sequences):
36
+ """Get maximum possible value
37
+ """
38
+ return max(map(len, sequences))
39
+
40
+ def normalized_distance(self, *sequences):
41
+ """Get distance from 0 to 1
42
+ """
43
+ return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
44
+
45
+ def rename(self, node1, node2):
46
+ """Compares attributes of trees"""
47
+ if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
48
+ return 1.
49
+ if node1.tag == 'td':
50
+ if node1.content or node2.content:
51
+ return self.normalized_distance(node1.content, node2.content)
52
+ return 0.
53
+
54
+
55
+ class TEDS(object):
56
+ ''' Tree Edit Distance based Similarity
57
+ '''
58
+ def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
59
+ assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
60
+ self.structure_only = structure_only
61
+ self.n_jobs = n_jobs
62
+ self.ignore_nodes = ignore_nodes
63
+ self.__tokens__ = []
64
+
65
+ def tokenize(self, node):
66
+ ''' Tokenizes table cells
67
+ '''
68
+ self.__tokens__.append('<%s>' % node.tag)
69
+ if node.text is not None:
70
+ self.__tokens__ += list(node.text)
71
+ for n in node.getchildren():
72
+ self.tokenize(n)
73
+ if node.tag != 'unk':
74
+ self.__tokens__.append('</%s>' % node.tag)
75
+ if node.tag != 'td' and node.tail is not None:
76
+ self.__tokens__ += list(node.tail)
77
+
78
+ def load_html_tree(self, node, parent=None):
79
+ ''' Converts HTML tree to the format required by apted
80
+ '''
81
+ global __tokens__
82
+ if node.tag == 'td':
83
+ if self.structure_only:
84
+ cell = []
85
+ else:
86
+ self.__tokens__ = []
87
+ self.tokenize(node)
88
+ cell = self.__tokens__[1:-1].copy()
89
+ new_node = TableTree(node.tag,
90
+ int(node.attrib.get('colspan', '1')),
91
+ int(node.attrib.get('rowspan', '1')),
92
+ cell, *deque())
93
+ else:
94
+ new_node = TableTree(node.tag, None, None, None, *deque())
95
+ if parent is not None:
96
+ parent.children.append(new_node)
97
+ if node.tag != 'td':
98
+ for n in node.getchildren():
99
+ self.load_html_tree(n, new_node)
100
+ if parent is None:
101
+ return new_node
102
+
103
+ def evaluate(self, pred, true):
104
+ ''' Computes TEDS score between the prediction and the ground truth of a
105
+ given sample
106
+ '''
107
+ if (not pred) or (not true):
108
+ return 0.0
109
+ pred = "<html>" + pred + "</html>"
110
+ true = "<html>" + true + "</html>"
111
+ parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
112
+ pred = html.fromstring(pred, parser=parser)
113
+ true = html.fromstring(true, parser=parser)
114
+ if pred.xpath('body/table') and true.xpath('body/table'):
115
+ pred = pred.xpath('body/table')[0]
116
+ true = true.xpath('body/table')[0]
117
+ if self.ignore_nodes:
118
+ etree.strip_tags(pred, *self.ignore_nodes)
119
+ etree.strip_tags(true, *self.ignore_nodes)
120
+ n_nodes_pred = len(pred.xpath(".//*"))
121
+ n_nodes_true = len(true.xpath(".//*"))
122
+ n_nodes = max(n_nodes_pred, n_nodes_true)
123
+ tree_pred = self.load_html_tree(pred)
124
+ tree_true = self.load_html_tree(true)
125
+ distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
126
+ return 1.0 - (float(distance) / n_nodes)
127
+ else:
128
+ return 0.0
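+ # The returned score is 1 - TED(pred, true) / max(#nodes(pred), #nodes(true)), where the
+ # tree edit distance is computed by APTED with the cell-aware costs defined in CustomConfig.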
129
+
130
+ def batch_evaluate(self, pred_json, true_json):
131
+ ''' Computes TEDS score between the prediction and the ground truth of
132
+ a batch of samples
133
+ @params pred_json: {'FILENAME': 'HTML CODE', ...}
134
+ @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
135
+ @output: {'FILENAME': 'TEDS SCORE', ...}
136
+ '''
137
+ samples = true_json.keys()
138
+ if self.n_jobs == 1:
139
+ scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
140
+ else:
141
+ inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
142
+ scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
143
+ total_score_simple = 0
144
+ num_simple = 0
145
+ total_score_complex = 0
146
+ num_complex = 0
147
+ total_score = 0
148
+ num_total = 0
149
+ for filename,score in zip(samples, scores):
150
+ print(filename)
151
+ print(score)
152
+ print('')
153
+ if true_json[filename]['type'] == 'simple':
154
+ total_score_simple += score
155
+ num_simple += 1
156
+ elif true_json[filename]['type'] == 'complex':
157
+ total_score_complex += score
158
+ num_complex += 1
159
+ else:
160
+ raise ValueError('Unknown type: %s' % true_json[filename]['type'])
161
+ total_score += score
162
+ num_total += 1
163
+ if num_simple > 0:
164
+ avg_score_simple = total_score_simple / num_simple
165
+ else:
166
+ avg_score_simple = 0
167
+ if num_complex > 0:
168
+ avg_score_complex = total_score_complex / num_complex
169
+ else:
170
+ avg_score_complex = 0
171
+ avg_score = total_score / num_total
172
+ print({'simple': (num_simple,avg_score_simple), 'complex': (num_complex,avg_score_complex), 'total': (num_total,avg_score)})
173
+
174
+ def main():
175
+ parser = argparse.ArgumentParser(description="Evaluate table_to_html task")
176
+ parser.add_argument(
177
+ "workspace",
178
+ help="The filesystem path where work will be stored, can be a local folder",
179
+ )
180
+ parser.add_argument(
181
+ "--gt_file",
182
+ help="Ground truth file",
183
+ )
184
+ parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
185
+ args = parser.parse_args()
186
+
187
+ pred_data = {}
188
+ root_dir = os.path.join(args.workspace, "results")
189
+ for jsonl_file in os.listdir(root_dir):
190
+ if jsonl_file.endswith(".jsonl"):
191
+ with open(os.path.join(root_dir, jsonl_file), "r") as f:
192
+ for line in f:
193
+ data = json.loads(line)
194
+ pred_data[os.path.basename(data['orig_path'])] = data['document_text']
195
+
196
+ gt_data = {}
197
+ with open(args.gt_file, "r") as f:
198
+ for line in f:
199
+ data = json.loads(line)
200
+ gt_data[data['image_name']] = {'html':data['gt_table'], 'type':data['type']}
201
+
202
+ teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
203
+ teds.batch_evaluate(pred_data, gt_data)
204
+
205
+ if __name__ == "__main__":
206
+ main()
eval/eval_table_to_html_nanonets.py ADDED
@@ -0,0 +1,295 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import distance
5
+ import markdown2
6
+ import re
7
+ from apted import APTED, Config
8
+ from apted.helpers import Tree
9
+ from lxml import etree, html
10
+ from collections import deque
11
+ from tqdm import tqdm
12
+ from eval.parallel import parallel_process
13
+ from bs4 import BeautifulSoup
14
+
15
+ def turn_header_to_h1(line):
16
+ # 检查是否是以一个或多个 '#' 开头的标题行
17
+ if line.lstrip().startswith('#'):
18
+ # 去掉开头的 '#' 和其后的空格
19
+ new_line = "# " + line.lstrip().lstrip('#').lstrip()
20
+ return new_line
21
+ else:
22
+ return line
23
+
24
+ def replace_single_dollar(markdown_text):
25
+ pattern = r'\$(.*?)\$'
26
+ def replace_with_brackets(match):
27
+ formula_content = match.group(1) # 获取匹配到的公式内容
28
+ return f'\\({formula_content}\\)'
29
+
30
+ replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
31
+
32
+ return replaced_text
33
+
34
+
35
+ def replace_double_dollar(markdown_text):
36
+ pattern = r'\$\$(.*?)\$\$'
37
+ def replace_with_brackets(match):
38
+ formula_content = match.group(1)
39
+ return f'\\[{formula_content}\\]'
40
+ replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
41
+
42
+ return replaced_text
43
+
44
+
45
+ class TableTree(Tree):
46
+ def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
47
+ self.tag = tag
48
+ self.colspan = colspan
49
+ self.rowspan = rowspan
50
+ self.content = content
51
+ self.children = list(children)
52
+
53
+ def bracket(self):
54
+ """Show tree using brackets notation"""
55
+ if self.tag == 'td':
56
+ result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
57
+ (self.tag, self.colspan, self.rowspan, self.content)
58
+ else:
59
+ result = '"tag": %s' % self.tag
60
+ for child in self.children:
61
+ result += child.bracket()
62
+ return "{{{}}}".format(result)
63
+
64
+
65
+ class CustomConfig(Config):
66
+ @staticmethod
67
+ def maximum(*sequences):
68
+ """Get maximum possible value
69
+ """
70
+ return max(map(len, sequences))
71
+
72
+ def normalized_distance(self, *sequences):
73
+ """Get distance from 0 to 1
74
+ """
75
+ return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
76
+
77
+ def rename(self, node1, node2):
78
+ """Compares attributes of trees"""
79
+ if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
80
+ return 1.
81
+ if node1.tag == 'td':
82
+ if node1.content or node2.content:
83
+ return self.normalized_distance(node1.content, node2.content)
84
+ return 0.
85
+
86
+
87
+ class TEDS(object):
88
+ ''' Tree Edit Distance based Similarity
89
+ '''
90
+ def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
91
+ assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
92
+ self.structure_only = structure_only
93
+ self.n_jobs = n_jobs
94
+ self.ignore_nodes = ignore_nodes
95
+ self.__tokens__ = []
96
+
97
+ def tokenize(self, node):
98
+ ''' Tokenizes table cells
99
+ '''
100
+ self.__tokens__.append('<%s>' % node.tag)
101
+ if node.text is not None:
102
+ self.__tokens__ += list(node.text)
103
+ for n in node.getchildren():
104
+ self.tokenize(n)
105
+ if node.tag != 'unk':
106
+ self.__tokens__.append('</%s>' % node.tag)
107
+ if node.tag != 'td' and node.tail is not None:
108
+ self.__tokens__ += list(node.tail)
109
+
110
+ def load_html_tree(self, node, parent=None):
111
+ ''' Converts HTML tree to the format required by apted
112
+ '''
113
+ global __tokens__
114
+ if node.tag == 'td':
115
+ if self.structure_only:
116
+ cell = []
117
+ else:
118
+ self.__tokens__ = []
119
+ self.tokenize(node)
120
+ cell = self.__tokens__[1:-1].copy()
121
+ new_node = TableTree(node.tag,
122
+ int(node.attrib.get('colspan', '1')),
123
+ int(node.attrib.get('rowspan', '1')),
124
+ cell, *deque())
125
+ else:
126
+ new_node = TableTree(node.tag, None, None, None, *deque())
127
+ if parent is not None:
128
+ parent.children.append(new_node)
129
+ if node.tag != 'td':
130
+ for n in node.getchildren():
131
+ self.load_html_tree(n, new_node)
132
+ if parent is None:
133
+ return new_node
134
+
135
+ def evaluate(self, pred, true):
136
+ ''' Computes TEDS score between the prediction and the ground truth of a
137
+ given sample
138
+ '''
139
+ if (not pred) or (not true):
140
+ return 0.0
141
+ pred.replace("<th>","<td>")
142
+ pred.replace("</th>","</td>")
143
+ pred = "<html>" + pred + "</html>"
144
+ true = "<html>" + true + "</html>"
145
+ parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
146
+ pred = html.fromstring(pred, parser=parser)
147
+ true = html.fromstring(true, parser=parser)
148
+ if pred.xpath('body/table') and true.xpath('body/table'):
149
+ pred = pred.xpath('body/table')[0]
150
+ true = true.xpath('body/table')[0]
151
+ if self.ignore_nodes:
152
+ etree.strip_tags(pred, *self.ignore_nodes)
153
+ etree.strip_tags(true, *self.ignore_nodes)
154
+ n_nodes_pred = len(pred.xpath(".//*"))
155
+ n_nodes_true = len(true.xpath(".//*"))
156
+ n_nodes = max(n_nodes_pred, n_nodes_true)
157
+ tree_pred = self.load_html_tree(pred)
158
+ tree_true = self.load_html_tree(true)
159
+ distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
160
+ return 1.0 - (float(distance) / n_nodes)
161
+ else:
162
+ return 0.0
163
+
164
+ def batch_evaluate(self, pred_json, true_json):
165
+ ''' Computes TEDS score between the prediction and the ground truth of
166
+ a batch of samples
167
+ @params pred_json: {'FILENAME': 'HTML CODE', ...}
168
+ @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
169
+ @output: {'FILENAME': 'TEDS SCORE', ...}
170
+ '''
171
+ samples = true_json.keys()
172
+ if self.n_jobs == 1:
173
+ scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
174
+ else:
175
+ inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
176
+ scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
177
+ total_score_simple = 0
178
+ num_simple = 0
179
+ total_score_complex = 0
180
+ num_complex = 0
181
+ total_score = 0
182
+ num_total = 0
183
+ for filename,score in zip(samples, scores):
184
+ print(filename)
185
+ print(score)
186
+ print('')
187
+ if true_json[filename]['type'] == 'simple':
188
+ total_score_simple += score
189
+ num_simple += 1
190
+ elif true_json[filename]['type'] == 'complex':
191
+ total_score_complex += score
192
+ num_complex += 1
193
+ else:
194
+ raise ValueError('Unknown type: %s' % true_json[filename]['type'])
195
+ total_score += score
196
+ num_total += 1
197
+ if num_simple > 0:
198
+ avg_score_simple = total_score_simple / num_simple
199
+ else:
200
+ avg_score_simple = 0
201
+ if num_complex > 0:
202
+ avg_score_complex = total_score_complex / num_complex
203
+ else:
204
+ avg_score_complex = 0
205
+ avg_score = total_score / num_total
206
+ print({'simple': (num_simple,avg_score_simple), 'complex': (num_complex,avg_score_complex), 'total': (num_total,avg_score)})
207
+
208
+ def simplify_html_table(html_table):
209
+ # 使用 BeautifulSoup 解析 HTML
210
+ soup = BeautifulSoup(html_table, 'html.parser')
211
+
212
+ # 找到 <table> 标签
213
+ table = soup.find('table')
214
+ if not table:
215
+ raise ValueError("输入的 HTML 不包含有效的 <table> 标签")
216
+
217
+ # 创建一个新的 <table> 标签
218
+ new_table = BeautifulSoup('<table></table>', 'html.parser').table
219
+
220
+ # 提取所有行(包括 <thead> 和 <tbody> 中的行)
221
+ rows = table.find_all(['tr'], recursive=True)
222
+
223
+ for row in rows:
224
+ # 创建新的 <tr> 标签
225
+ new_row = soup.new_tag('tr')
226
+
227
+ # 处理每一行中的单元格
228
+ cells = row.find_all(['th', 'td'])
229
+ for cell in cells:
230
+ # 将 <th> 替换为 <td>
231
+ new_cell = soup.new_tag('td')
232
+ if cell.has_attr('rowspan'):
233
+ new_cell['rowspan'] = cell['rowspan']
234
+ if cell.has_attr('colspan'):
235
+ new_cell['colspan'] = cell['colspan']
236
+ new_cell.string = cell.get_text(strip=True) # 保留单元格内容
237
+ new_row.append(new_cell)
238
+
239
+ # 将新行添加到新表格中
240
+ new_table.append(new_row)
241
+
242
+ # 返回简化后的表格 HTML
243
+ return str(new_table)
244
+
245
+
246
+ def main():
247
+ parser = argparse.ArgumentParser(description="Evaluate table_to_html task")
248
+ parser.add_argument(
249
+ "workspace",
250
+ help="The filesystem path where work will be stored, can be a local folder",
251
+ )
252
+ parser.add_argument(
253
+ "--gt_file",
254
+ help="Ground truth file",
255
+ )
256
+ parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
257
+ args = parser.parse_args()
258
+
259
+ pred_data = {}
260
+ for file in os.listdir(args.workspace):
261
+ file_path = os.path.join(args.workspace, file)
262
+ pdf_name = file.split('.')[0] + ".png"
263
+ with open(file_path, "r") as f:
264
+ document_text = f.read()
265
+ document_text = replace_single_dollar(replace_double_dollar(document_text))
266
+ markdown_text_list = document_text.split("\n\n")
267
+ new_markdown_text_list = []
268
+ for text in markdown_text_list:
269
+ text = text.strip()
270
+ if (text.startswith("<watermark>") and text.endswith("</watermark>")) or (text.startswith("<img>") and text.endswith("</img>")) or (text.startswith("<page_number>") and text.endswith("</page_number>")) or (text.startswith("<signature>") and text.endswith("</signature>")):
271
+ continue
272
+ else:
273
+ html_text = str(markdown2.markdown(text,extras=["tables"]))
274
+ html_text = html_text.strip()
275
+ if html_text.startswith("<table>") and html_text.endswith("</table>"):
276
+ html_table = simplify_html_table(html_text)
277
+ new_markdown_text_list.append(html_table)
278
+ else:
279
+ text = turn_header_to_h1(text)
280
+ new_markdown_text_list.append(text)
281
+
282
+ pred_data[os.path.basename(pdf_name)] = "\n\n".join(new_markdown_text_list)
283
+
284
+
285
+ gt_data = {}
286
+ with open(args.gt_file, "r") as f:
287
+ for line in f:
288
+ data = json.loads(line)
289
+ gt_data[data['image_name']] = {'html':data['gt_table'], 'type':data['type']}
290
+
291
+ teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
292
+ teds.batch_evaluate(pred_data, gt_data)
293
+
294
+ if __name__ == "__main__":
295
+ main()
eval/eval_table_to_html_olmocr.py ADDED
@@ -0,0 +1,212 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import distance
5
+ import markdown2
6
+ from apted import APTED, Config
7
+ from apted.helpers import Tree
8
+ from lxml import etree, html
9
+ from collections import deque
10
+ from tqdm import tqdm
11
+ from eval.parallel import parallel_process
12
+
13
+
14
+ class TableTree(Tree):
15
+ def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
16
+ self.tag = tag
17
+ self.colspan = colspan
18
+ self.rowspan = rowspan
19
+ self.content = content
20
+ self.children = list(children)
21
+
22
+ def bracket(self):
23
+ """Show tree using brackets notation"""
24
+ if self.tag == 'td':
25
+ result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
26
+ (self.tag, self.colspan, self.rowspan, self.content)
27
+ else:
28
+ result = '"tag": %s' % self.tag
29
+ for child in self.children:
30
+ result += child.bracket()
31
+ return "{{{}}}".format(result)
32
+
33
+
34
+ class CustomConfig(Config):
35
+ @staticmethod
36
+ def maximum(*sequences):
37
+ """Get maximum possible value
38
+ """
39
+ return max(map(len, sequences))
40
+
41
+ def normalized_distance(self, *sequences):
42
+ """Get distance from 0 to 1
43
+ """
44
+ return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
45
+
46
+ def rename(self, node1, node2):
47
+ """Compares attributes of trees"""
48
+ if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
49
+ return 1.
50
+ if node1.tag == 'td':
51
+ if node1.content or node2.content:
52
+ return self.normalized_distance(node1.content, node2.content)
53
+ return 0.
54
+
55
+
56
+ class TEDS(object):
57
+ ''' Tree Edit Distance based Similarity
58
+ '''
59
+ def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
60
+ assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
61
+ self.structure_only = structure_only
62
+ self.n_jobs = n_jobs
63
+ self.ignore_nodes = ignore_nodes
64
+ self.__tokens__ = []
65
+
66
+ def tokenize(self, node):
67
+ ''' Tokenizes table cells
68
+ '''
69
+ self.__tokens__.append('<%s>' % node.tag)
70
+ if node.text is not None:
71
+ self.__tokens__ += list(node.text)
72
+ for n in node.getchildren():
73
+ self.tokenize(n)
74
+ if node.tag != 'unk':
75
+ self.__tokens__.append('</%s>' % node.tag)
76
+ if node.tag != 'td' and node.tail is not None:
77
+ self.__tokens__ += list(node.tail)
78
+
79
+ def load_html_tree(self, node, parent=None):
80
+ ''' Converts HTML tree to the format required by apted
81
+ '''
82
+ global __tokens__
83
+ if node.tag == 'td':
84
+ if self.structure_only:
85
+ cell = []
86
+ else:
87
+ self.__tokens__ = []
88
+ self.tokenize(node)
89
+ cell = self.__tokens__[1:-1].copy()
90
+ new_node = TableTree(node.tag,
91
+ int(node.attrib.get('colspan', '1')),
92
+ int(node.attrib.get('rowspan', '1')),
93
+ cell, *deque())
94
+ else:
95
+ new_node = TableTree(node.tag, None, None, None, *deque())
96
+ if parent is not None:
97
+ parent.children.append(new_node)
98
+ if node.tag != 'td':
99
+ for n in node.getchildren():
100
+ self.load_html_tree(n, new_node)
101
+ if parent is None:
102
+ return new_node
103
+
104
+ def evaluate(self, pred, true):
105
+ ''' Computes TEDS score between the prediction and the ground truth of a
106
+ given sample
107
+ '''
108
+ if (not pred) or (not true):
109
+ return 0.0
110
+ pred.replace("<th>","<td>")
111
+ pred.replace("</th>","</td>")
112
+ pred = "<html>" + pred + "</html>"
113
+ true = "<html>" + true + "</html>"
114
+ parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
115
+ pred = html.fromstring(pred, parser=parser)
116
+ true = html.fromstring(true, parser=parser)
117
+ if pred.xpath('body/table') and true.xpath('body/table'):
118
+ pred = pred.xpath('body/table')[0]
119
+ true = true.xpath('body/table')[0]
120
+ if self.ignore_nodes:
121
+ etree.strip_tags(pred, *self.ignore_nodes)
122
+ etree.strip_tags(true, *self.ignore_nodes)
123
+ n_nodes_pred = len(pred.xpath(".//*"))
124
+ n_nodes_true = len(true.xpath(".//*"))
125
+ n_nodes = max(n_nodes_pred, n_nodes_true)
126
+ tree_pred = self.load_html_tree(pred)
127
+ tree_true = self.load_html_tree(true)
128
+ distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
129
+ return 1.0 - (float(distance) / n_nodes)
130
+ else:
131
+ return 0.0
132
+
133
+ def batch_evaluate(self, pred_json, true_json):
134
+ ''' Computes TEDS score between the prediction and the ground truth of
135
+ a batch of samples
136
+ @params pred_json: {'FILENAME': 'HTML CODE', ...}
137
+ @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
138
+ @output: {'FILENAME': 'TEDS SCORE', ...}
139
+ '''
140
+ samples = true_json.keys()
141
+ if self.n_jobs == 1:
142
+ scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
143
+ else:
144
+ inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
145
+ scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
146
+ total_score_simple = 0
147
+ num_simple = 0
148
+ total_score_complex = 0
149
+ num_complex = 0
150
+ total_score = 0
151
+ num_total = 0
152
+ for filename,score in zip(samples, scores):
153
+ print(filename)
154
+ print(score)
155
+ print('')
156
+ if true_json[filename]['type'] == 'simple':
157
+ total_score_simple += score
158
+ num_simple += 1
159
+ elif true_json[filename]['type'] == 'complex':
160
+ total_score_complex += score
161
+ num_complex += 1
162
+ else:
163
+ raise ValueError('Unknown type: %s' % true_json[filename]['type'])
164
+ total_score += score
165
+ num_total += 1
166
+ if num_simple > 0:
167
+ avg_score_simple = total_score_simple / num_simple
168
+ else:
169
+ avg_score_simple = 0
170
+ if num_complex > 0:
171
+ avg_score_complex = total_score_complex / num_complex
172
+ else:
173
+ avg_score_complex = 0
174
+ avg_score = total_score / num_total
175
+ print({'simple': (num_simple,avg_score_simple), 'complex': (num_complex,avg_score_complex), 'total': (num_total,avg_score)})
176
+
177
+ def main():
178
+ parser = argparse.ArgumentParser(description="Evaluate table_to_html task")
179
+ parser.add_argument(
180
+ "workspace",
181
+ help="The filesystem path where work will be stored, can be a local folder",
182
+ )
183
+ parser.add_argument(
184
+ "--gt_file",
185
+ help="Ground truth file",
186
+ )
187
+ parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
188
+ args = parser.parse_args()
189
+
190
+ pred_data = {}
191
+ root_dir = os.path.join(args.workspace, "results")
192
+ for jsonl_file in os.listdir(root_dir):
193
+ if jsonl_file.endswith(".jsonl"):
194
+ with open(os.path.join(root_dir, jsonl_file), "r") as f:
195
+ for line in f:
196
+ data = json.loads(line)
197
+ pdf_path = os.path.basename(data['metadata']['Source-File'])
198
+ document_text = data['text']
199
+ pred_data[pdf_path] = str(markdown2.markdown(document_text,extras=["tables"]))
200
+
201
+
202
+ gt_data = {}
203
+ with open(args.gt_file, "r") as f:
204
+ for line in f:
205
+ data = json.loads(line)
206
+ gt_data[data['image_name']] = {'html':data['gt_table'], 'type':data['type']}
207
+
208
+ teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
209
+ teds.batch_evaluate(pred_data, gt_data)
210
+
211
+ if __name__ == "__main__":
212
+ main()
eval/gen_element_merge_detect_data.py ADDED
@@ -0,0 +1,36 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+
5
+ def main():
6
+ parser = argparse.ArgumentParser(description="Generate evaluation data for the element_merge_detect task")
7
+ parser.add_argument(
8
+ "workspace",
9
+ help="The filesystem path where work will be stored, can be a local folder",
10
+ )
11
+ args = parser.parse_args()
12
+
13
+ json_dir = os.path.join(args.workspace, 'jsons')
14
+ if not os.path.exists(json_dir):
15
+ os.makedirs(json_dir)
16
+
17
+ jsonl_file = os.path.join(args.workspace, "data.jsonl")
18
+ with open(jsonl_file, "r") as f:
19
+ for line in f:
20
+ data = json.loads(line)
21
+ pdf_name_1 = data['pdf_name_1'].split(".")[0]
22
+ pdf_name_2 = data['pdf_name_2'].split(".")[0]
23
+
24
+ pdf_name,page_1 = pdf_name_1.split('_')
25
+ pdf_name,page_2 = pdf_name_2.split('_')
26
+
27
+ json_name = os.path.join(json_dir, pdf_name + '_' + page_1 + '_' + page_2 + '.json')
28
+ data = {
29
+ "page_1": "\n\n".join(data['md_elem_list_1']),
30
+ "page_2": "\n\n".join(data['md_elem_list_2']),
31
+ }
32
+ with open(json_name, 'w') as f:
33
+ json.dump(data, f)
34
+
35
+ if __name__ == "__main__":
36
+ main()
eval/gen_html_table_merge_data.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+
5
+ def main():
6
+ parser = argparse.ArgumentParser(description="Generate evaluation data for the html_table_merge task")
7
+ parser.add_argument(
8
+ "workspace",
9
+ help="The filesystem path where work will be stored, can be a local folder",
10
+ )
11
+ args = parser.parse_args()
12
+
13
+ json_dir = os.path.join(args.workspace, 'jsons')
14
+ if not os.path.exists(json_dir):
15
+ os.makedirs(json_dir)
16
+
17
+ jsonl_file = os.path.join(args.workspace, 'data.jsonl')
18
+ with open(jsonl_file, "r") as f:
19
+ for line in f:
20
+ data = json.loads(line)
21
+ json_name = data['image_name'].split('.')[0] + '.json'
22
+
23
+ json_path = os.path.join(json_dir, json_name)
24
+ data = {
25
+ "table_1": data['table_fragment_1'],
26
+ "table_2": data['table_fragment_2'],
27
+ }
28
+ with open(json_path, 'w') as f:
29
+ json.dump(data, f)
30
+
31
+ if __name__ == "__main__":
32
+ main()
eval/parallel.py ADDED
@@ -0,0 +1,50 @@
1
+ from tqdm import tqdm
2
+ from concurrent.futures import ProcessPoolExecutor, as_completed
3
+
4
+ def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
5
+ """
6
+ A parallel version of the map function with a progress bar.
7
+
8
+ Args:
9
+ array (array-like): An array to iterate over.
10
+ function (function): A python function to apply to the elements of array
11
+ n_jobs (int, default=16): The number of cores to use
12
+ use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
13
+ keyword arguments to function
14
+ front_num (int, default=0): The number of iterations to run serially before kicking off the parallel job.
15
+ Useful for catching bugs
16
+ Returns:
17
+ [function(array[0]), function(array[1]), ...]
18
+ """
19
+ # We run the first few iterations serially to catch bugs
20
+ if front_num > 0:
21
+ front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
22
+ else:
23
+ front = []
24
+ # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
25
+ if n_jobs == 1:
26
+ return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
27
+ # Assemble the workers
28
+ with ProcessPoolExecutor(max_workers=n_jobs) as pool:
29
+ # Pass the elements of array into function
30
+ if use_kwargs:
31
+ futures = [pool.submit(function, **a) for a in array[front_num:]]
32
+ else:
33
+ futures = [pool.submit(function, a) for a in array[front_num:]]
34
+ kwargs = {
35
+ 'total': len(futures),
36
+ 'unit': 'it',
37
+ 'unit_scale': True,
38
+ 'leave': True
39
+ }
40
+ # Print out the progress as tasks complete
41
+ for f in tqdm(as_completed(futures), **kwargs):
42
+ pass
43
+ out = []
44
+ # Get the results from the futures.
45
+ for i, future in tqdm(enumerate(futures)):
46
+ try:
47
+ out.append(future.result())
48
+ except Exception as e:
49
+ out.append(e)
50
+ return front + out
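+ # Typical use in the eval scripts:
+ # scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)
+ # where each element of `inputs` is a dict of keyword arguments for `evaluate`.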
ocrflux/check.py ADDED
@@ -0,0 +1,44 @@
1
+ import importlib.util
2
+ import logging
3
+ import subprocess
4
+ import sys
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ def check_poppler_version():
10
+ try:
11
+ result = subprocess.run(["pdftoppm", "-h"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
12
+ if result.returncode == 0 and result.stderr.startswith("pdftoppm"):
13
+ logger.info("pdftoppm is installed and working.")
14
+ else:
15
+ logger.error("pdftoppm is installed but returned an error.")
16
+ sys.exit(1)
17
+ except FileNotFoundError:
18
+ logger.error("pdftoppm is not installed.")
19
+ sys.exit(1)
20
+
21
+ def check_vllm_version():
22
+ if importlib.util.find_spec("vllm") is None:
23
+ logger.error("VLLM needs to be installed with a separate command in order to find all dependencies properly.")
24
+ sys.exit(1)
25
+
26
+
27
+ def check_torch_gpu_available(min_gpu_memory: int = 20 * 1024**3):
28
+ try:
29
+ import torch
30
+ except:
31
+ logger.error("Pytorch must be installed, visit https://pytorch.org/ for installation instructions")
32
+ raise
33
+
34
+ try:
35
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory
36
+ assert gpu_memory >= min_gpu_memory
37
+ except:
38
+ logger.error(f"Torch was not able to find a GPU with at least {min_gpu_memory // (1024 ** 3)} GB of RAM.")
39
+ raise
40
+
41
+
42
+ if __name__ == "__main__":
43
+ check_poppler_version()
44
+ check_vllm_version()
ocrflux/image_utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import io
4
+ from typing import List, Union
5
+ from PIL import Image
6
+
7
+
8
+ def get_page_image(pdf_path, page_number, target_longest_image_dim=None, image_rotation=0):
9
+ if pdf_path.lower().endswith(".pdf"):
10
+ # Convert PDF page to PNG using pdftoppm
11
+ pdftoppm_result = subprocess.run(
12
+ [
13
+ "pdftoppm",
14
+ "-png",
15
+ "-f",
16
+ str(page_number),
17
+ "-l",
18
+ str(page_number),
19
+ "-r",
20
+ "72", # 72 pixels per point is the conversion factor
21
+ pdf_path,
22
+ ],
23
+ timeout=120,
24
+ stdout=subprocess.PIPE,
25
+ stderr=subprocess.PIPE,
26
+ )
27
+ assert pdftoppm_result.returncode == 0, pdftoppm_result.stderr
28
+ image = Image.open(io.BytesIO(pdftoppm_result.stdout))
29
+ else:
30
+ image = Image.open(pdf_path)
31
+ if image_rotation != 0:
32
+ image = image.rotate(-image_rotation, expand=True)
33
+ if target_longest_image_dim is not None:
34
+ width, height = image.size
35
+ if width > height:
36
+ new_width = target_longest_image_dim
37
+ new_height = int(height * (target_longest_image_dim / width))
38
+ else:
39
+ new_height = target_longest_image_dim
40
+ new_width = int(width * (target_longest_image_dim / height))
41
+ image = image.resize((new_width, new_height))
42
+ return image
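+ # At 72 DPI one rendered pixel corresponds to one PDF point, so a US-Letter page
+ # (612 x 792 pt) comes out at roughly 612 x 792 pixels before the optional resize.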
43
+
44
+
45
+ def is_image(file_path):
46
+ try:
47
+ Image.open(file_path)
48
+ return True
49
+ except:
50
+ return False
ocrflux/inference.py ADDED
@@ -0,0 +1,237 @@
1
+ import json
2
+ import copy
3
+ from PIL import Image
4
+ from pypdf import PdfReader
5
+ from vllm import LLM, SamplingParams
6
+ from ocrflux.image_utils import get_page_image
7
+ from ocrflux.table_format import table_matrix2html
8
+ from ocrflux.prompts import PageResponse, build_page_to_markdown_prompt, build_element_merge_detect_prompt, build_html_table_merge_prompt
9
+
10
+ def build_qwen2_5_vl_prompt(question):
11
+ return (
12
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
13
+ f"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
14
+ f"{question}<|im_end|>\n"
15
+ "<|im_start|>assistant\n"
16
+ )
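+ # This mirrors the Qwen2.5-VL chat-template layout: a system turn, a user turn whose image
+ # is injected through the vision placeholder tokens, and an open assistant turn for generation.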
17
+
18
+ def build_page_to_markdown_query(file_path: str, page_number: int, target_longest_image_dim: int = 1024, image_rotation: int = 0) -> dict:
19
+ assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
20
+ image = get_page_image(file_path, page_number, target_longest_image_dim=target_longest_image_dim, image_rotation=image_rotation)
21
+ question = build_page_to_markdown_prompt()
22
+ prompt = build_qwen2_5_vl_prompt(question)
23
+ query = {
24
+ "prompt": prompt,
25
+ "multi_modal_data": {"image": image},
26
+ }
27
+ return query
28
+
29
+ def build_element_merge_detect_query(text_list_1,text_list_2) -> dict:
30
+ image = Image.new('RGB', (28, 28), color='black')
31
+ question = build_element_merge_detect_prompt(text_list_1,text_list_2)
32
+ prompt = build_qwen2_5_vl_prompt(question)
33
+ query = {
34
+ "prompt": prompt,
35
+ "multi_modal_data": {"image": image},
36
+ }
37
+ return query
38
+
39
+ def build_html_table_merge_query(text_1,text_2) -> dict:
40
+ image = Image.new('RGB', (28, 28), color='black')
41
+ question = build_html_table_merge_prompt(text_1,text_2)
42
+ prompt = build_qwen2_5_vl_prompt(question)
43
+ query = {
44
+ "prompt": prompt,
45
+ "multi_modal_data": {"image": image},
46
+ }
47
+ return query
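+ # The merge-detect and table-merge stages are text-only, but the prompt template above always
+ # contains an image placeholder, so a dummy 28x28 black image is passed to satisfy it.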
48
+
49
+ def bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result):
50
+ page_to_markdown_keys = list(page_to_markdown_result.keys())
51
+ element_merge_detect_keys = list(element_merge_detect_result.keys())
52
+ html_table_merge_keys = list(html_table_merge_result.keys())
53
+
54
+ for page_1,page_2,elem_idx_1,elem_idx_2 in sorted(html_table_merge_keys,key=lambda x: -x[0]):
55
+ page_to_markdown_result[page_1][elem_idx_1] = html_table_merge_result[(page_1,page_2,elem_idx_1,elem_idx_2)]
56
+ page_to_markdown_result[page_2][elem_idx_2] = ''
57
+
58
+ for page_1,page_2 in sorted(element_merge_detect_keys,key=lambda x: -x[0]):
59
+ for elem_idx_1,elem_idx_2 in element_merge_detect_result[(page_1,page_2)]:
60
+ if len(page_to_markdown_result[page_1][elem_idx_1]) == 0 or page_to_markdown_result[page_1][elem_idx_1][-1] == '-' or ('\u4e00' <= page_to_markdown_result[page_1][elem_idx_1][-1] <= '\u9fff'):
61
+ page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + '' + page_to_markdown_result[page_2][elem_idx_2]
62
+ else:
63
+ page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + ' ' + page_to_markdown_result[page_2][elem_idx_2]
64
+ page_to_markdown_result[page_2][elem_idx_2] = ''
65
+
66
+ document_text_list = []
67
+ for page in page_to_markdown_keys:
68
+ page_text_list = [s for s in page_to_markdown_result[page] if s]
69
+ document_text_list += page_text_list
70
+ return "\n\n".join(document_text_list)
71
+
72
+ def parse(llm,file_path,skip_cross_page_merge=False,max_page_retries=0):
73
+ sampling_params = SamplingParams(temperature=0.0,max_tokens=8192)
74
+ if file_path.lower().endswith(".pdf"):
75
+ try:
76
+ reader = PdfReader(file_path)
77
+ num_pages = reader.get_num_pages()
78
+ except:
79
+ return None
80
+ else:
81
+ num_pages = 1
82
+
83
+ try:
84
+ # Stage 1: Page to Markdown
85
+ page_to_markdown_query_list = [build_page_to_markdown_query(file_path,page_num) for page_num in range(1, num_pages + 1)]
86
+ responses = llm.generate(page_to_markdown_query_list, sampling_params=sampling_params)
87
+ results = [response.outputs[0].text for response in responses]
88
+ page_to_markdown_result = {}
89
+ retry_list = []
90
+ for i,result in enumerate(results):
91
+ try:
92
+ json_data = json.loads(result)
93
+ page_response = PageResponse(**json_data)
94
+ natural_text = page_response.natural_text
95
+ markdown_element_list = []
96
+ for text in natural_text.split('\n\n'):
97
+ if text.startswith("<Image>") and text.endswith("</Image>"):
98
+ pass
99
+ elif text.startswith("<table>") and text.endswith("</table>"):
100
+ try:
101
+ new_text = table_matrix2html(text)
102
+ except:
103
+ new_text = text.replace("<t>","").replace("<l>","").replace("<lt>","")
104
+ markdown_element_list.append(new_text)
105
+ else:
106
+ markdown_element_list.append(text)
107
+ page_to_markdown_result[i+1] = markdown_element_list
108
+ except:
109
+ retry_list.append(i)
110
+
111
+ attempt = 0
112
+ while len(retry_list) > 0 and attempt < max_page_retries:
113
+ retry_page_to_markdown_query_list = [build_page_to_markdown_query(file_path,page_num) for page_num in retry_list]
114
+ retry_sampling_params = SamplingParams(temperature=0.1*attempt, max_tokens=8192)
115
+ responses = llm.generate(retry_page_to_markdown_query_list, sampling_params=retry_sampling_params)
116
+ results = [response.outputs[0].text for response in responses]
117
+ next_retry_list = []
118
+ for i,result in zip(retry_list,results):
119
+ try:
120
+ json_data = json.loads(result)
121
+ page_response = PageResponse(**json_data)
122
+ natural_text = page_response.natural_text
123
+ markdown_element_list = []
124
+ for text in natural_text.split('\n\n'):
125
+ if text.startswith("<Image>") and text.endswith("</Image>"):
126
+ pass
127
+ elif text.startswith("<table>") and text.endswith("</table>"):
128
+ try:
129
+ new_text = table_matrix2html(text)
130
+ except:
131
+ new_text = text.replace("<t>","").replace("<l>","").replace("<lt>","")
132
+ markdown_element_list.append(new_text)
133
+ else:
134
+ markdown_element_list.append(text)
135
+ page_to_markdown_result[i+1] = markdown_element_list
136
+ except:
137
+ next_retry_list.append(i)
138
+ retry_list = next_retry_list
139
+ attempt += 1
140
+
141
+ page_texts = {}
142
+ fallback_pages = []
143
+ for page_number in range(1, num_pages+1):
144
+ if page_number not in page_to_markdown_result.keys():
145
+ fallback_pages.append(page_number-1)
146
+ else:
147
+ page_texts[str(page_number-1)] = "\n\n".join(page_to_markdown_result[page_number])
148
+
149
+ if skip_cross_page_merge:
150
+ document_text_list = []
151
+ for i in range(num_pages):
152
+ if i not in fallback_pages:
153
+ document_text_list.append(page_texts[str(i)])
154
+ document_text = "\n\n".join(document_text_list)
155
+ return {
156
+ "orig_path": file_path,
157
+ "num_pages": num_pages,
158
+ "document_text": document_text,
159
+ "page_texts": page_texts,
160
+ "fallback_pages": fallback_pages,
161
+ }
162
+
163
+ # Stage 2: Element Merge Detect
164
+ element_merge_detect_keys = []
165
+ element_merge_detect_query_list = []
166
+ for page_num in range(1,num_pages):
167
+ if page_num in page_to_markdown_result.keys() and page_num+1 in page_to_markdown_result.keys():
168
+ element_merge_detect_query_list.append(build_element_merge_detect_query(page_to_markdown_result[page_num],page_to_markdown_result[page_num+1]))
169
+ element_merge_detect_keys.append((page_num,page_num+1))
170
+ responses = llm.generate(element_merge_detect_query_list, sampling_params=sampling_params)
171
+ results = [response.outputs[0].text for response in responses]
172
+ element_merge_detect_result = {}
173
+ for key,result in zip(element_merge_detect_keys,results):
174
+ try:
175
+ element_merge_detect_result[key] = eval(result)
176
+ except:
177
+ pass
178
+
179
+ # Stage 3: HTML Table Merge
180
+ html_table_merge_keys = []
181
+ for key,result in element_merge_detect_result.items():
182
+ page_1,page_2 = key
183
+ for elem_idx_1,elem_idx_2 in result:
184
+ text_1 = page_to_markdown_result[page_1][elem_idx_1]
185
+ text_2 = page_to_markdown_result[page_2][elem_idx_2]
186
+ if text_1.startswith("<table>") and text_1.endswith("</table>") and text_2.startswith("<table>") and text_2.endswith("</table>"):
187
+ html_table_merge_keys.append((page_1,page_2,elem_idx_1,elem_idx_2))
188
+
189
+ html_table_merge_keys = sorted(html_table_merge_keys,key=lambda x: -x[0])
190
+
191
+ html_table_merge_result = {}
192
+ page_to_markdown_result_tmp = copy.deepcopy(page_to_markdown_result)
193
+ i = 0
194
+ while i < len(html_table_merge_keys):
195
+ tmp = set()
196
+ keys = []
197
+ while i < len(html_table_merge_keys):
198
+ page_1,page_2,elem_idx_1,elem_idx_2 = html_table_merge_keys[i]
199
+ if (page_2,elem_idx_2) in tmp:
200
+ break
201
+ tmp.add((page_1,elem_idx_1))
202
+ keys.append((page_1,page_2,elem_idx_1,elem_idx_2))
203
+ i += 1
204
+
205
+ html_table_merge_query_list = [build_html_table_merge_query(page_to_markdown_result_tmp[page_1][elem_idx_1],page_to_markdown_result_tmp[page_2][elem_idx_2]) for page_1,page_2,elem_idx_1,elem_idx_2 in keys]
206
+ responses = llm.generate(html_table_merge_query_list, sampling_params=sampling_params)
207
+ results = [response.outputs[0].text for response in responses]
208
+ for key,result in zip(keys,results):
209
+ if result.startswith("<table>") and result.endswith("</table>"):
210
+ html_table_merge_result[key] = result
211
+ page_to_markdown_result_tmp[page_1][elem_idx_1] = result
212
+
213
+ document_text = bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result)
214
+ return {
215
+ "orig_path": file_path,
216
+ "num_pages": num_pages,
217
+ "document_text": document_text,
218
+ "page_texts": page_texts,
219
+ "fallback_pages": fallback_pages,
220
+ }
221
+ except:
222
+ return None
223
+
224
+
225
+ if __name__ == '__main__':
226
+ file_path = 'test.pdf'
227
+ llm = LLM(model="ChatDOC/OCRFlux-3B",gpu_memory_utilization=0.8,max_model_len=8192)
228
+ result = parse(llm,file_path,max_page_retries=4)
229
+ if result != None:
230
+ document_markdown = result['document_text']
231
+ print(document_markdown)
232
+ with open('test.md','w') as f:
233
+ f.write(document_markdown)
234
+ else:
235
+ print("Parse failed")
236
+
237
+
ocrflux/jsonl_to_markdown.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ def main():
5
+ parser = argparse.ArgumentParser(description="Convert result JSONL files to markdown")
6
+ parser.add_argument(
7
+ "workspace",
8
+ help="The filesystem path where work will be stored, can be a local folder",
9
+ )
10
+ parser.add_argument("--show_page_result", action="store_true", help="Whether to show the markdown of each page")
11
+ args = parser.parse_args()
12
+
13
+ src_dir = os.path.join(args.workspace, "results")
14
+ tgt_dir = os.path.join(args.workspace, "markdowns")
15
+ if not os.path.exists(tgt_dir):
16
+ os.makedirs(tgt_dir)
17
+ for jsonl_file in os.listdir(src_dir):
18
+ if jsonl_file.endswith(".jsonl"):
19
+ with open(os.path.join(src_dir, jsonl_file), "r") as f:
20
+ for line in f:
21
+ data = json.loads(line)
22
+ markdown_text = data['document_text']
23
+ file_name = os.path.basename(data['orig_path']).split(".")[0]
24
+ file_dir = os.path.join(tgt_dir, file_name)
25
+ if not os.path.exists(file_dir):
26
+ os.makedirs(file_dir)
27
+ with open(os.path.join(file_dir, file_name+".md"), "w") as f:
28
+ f.write(markdown_text)
29
+ if args.show_page_result:
30
+ page_texts = data["page_texts"]
31
+ for page_num in page_texts.keys():
32
+ page_text = page_texts[page_num]
33
+ with open(os.path.join(file_dir, file_name+"_"+str(page_num)+".md"), "w") as out_f:
34
+ out_f.write(page_text)
35
+
36
+ if __name__ == "__main__":
37
+ main()
ocrflux/metrics.py ADDED
@@ -0,0 +1,147 @@
1
+ import asyncio
2
+ import time
3
+ from collections import defaultdict, deque
4
+ from typing import Any, Deque, Dict, List, Set
5
+
6
+
7
+ class MetricsKeeper:
8
+ def __init__(self, window=60 * 5):
9
+ """
10
+ Initializes the MetricsKeeper.
11
+
12
+ Args:
13
+ window (int): Time window in seconds for recent metrics. Defaults to 5 minutes.
14
+ """
15
+ self.window = window # Time window in seconds
16
+ self.start_time = time.time() # Timestamp when MetricsKeeper was created
17
+ self.total_metrics = defaultdict(int) # Cumulative metrics since start
18
+ self.window_metrics: Deque[Any] = deque() # Deque to store (timestamp, metrics_dict)
19
+ self.window_sum = defaultdict(int) # Sum of metrics within the window
20
+
21
+ def add_metrics(self, **kwargs):
22
+ """
23
+ Adds metrics to the keeper.
24
+
25
+ Args:
26
+ **kwargs: Arbitrary keyword arguments representing metric names and their values.
27
+ """
28
+ current_time = time.time()
29
+ # Update cumulative metrics
30
+ for key, value in kwargs.items():
31
+ self.total_metrics[key] += value
32
+
33
+ # Append current metrics with timestamp to the deque
34
+ self.window_metrics.append((current_time, kwargs))
35
+
36
+ # Update window sums
37
+ for key, value in kwargs.items():
38
+ self.window_sum[key] += value
39
+
40
+ # Remove metrics that are outside the time window
41
+ while self.window_metrics and self.window_metrics[0][0] < current_time - self.window:
42
+ old_time, old_metrics = self.window_metrics.popleft()
43
+ for key, value in old_metrics.items():
44
+ self.window_sum[key] -= value
45
+ if self.window_sum[key] <= 0:
46
+ del self.window_sum[key] # Clean up to prevent negative counts
47
+
48
+ def __str__(self):
49
+ """
50
+ Returns a formatted string of metrics showing tokens/sec since start and within the window.
51
+
52
+ Returns:
53
+ str: Formatted metrics string as a table.
54
+ """
55
+ current_time = time.time()
56
+ elapsed_time = current_time - self.start_time
57
+ window_time = min(self.window, elapsed_time) if elapsed_time > 0 else 1 # Prevent division by zero
58
+
59
+ # Header
60
+ header = f"{'Metric Name':<30} {'Lifetime (tokens/sec)':>25} {'Recently (tokens/sec)':>25}"
61
+ separator = "-" * len(header)
62
+ lines = [header, separator]
63
+
64
+ # Sort metrics alphabetically for consistency
65
+ for key in sorted(self.total_metrics.keys()):
66
+ total = self.total_metrics[key]
67
+ window = self.window_sum.get(key, 0)
68
+ total_rate = total / elapsed_time if elapsed_time > 0 else 0
69
+ window_rate = window / window_time if window_time > 0 else 0
70
+ line = f"{key:<30} {total_rate:>25.2f} {window_rate:>25.2f}"  # width matches the 30-char header column
71
+ lines.append(line)
72
+
73
+ return "\n".join(lines)
74
+
75
+
76
+ class WorkerTracker:
77
+ def __init__(self):
78
+ """
79
+ Initializes the WorkerTracker with a default dictionary.
80
+ Each worker ID maps to another dictionary that holds counts for each state.
81
+ """
82
+ # Mapping from worker_id to a dictionary of state counts
83
+ self.worker_status: Dict[int, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
84
+ self.lock = asyncio.Lock()
85
+
86
+ async def clear_work(self, worker_id: int):
87
+ async with self.lock:
88
+ self.worker_status[worker_id].clear()
89
+
90
+ async def track_work(self, worker_id: int, work_item_id: str, state: str):
91
+ """
92
+ Update the state count for a specific worker.
93
+
94
+ Args:
95
+ worker_id (int): The ID of the worker.
96
+ work_item_id (str): The unique identifier of the work item (unused in this implementation).
97
+ state (str): The state to increment for the work item.
98
+ """
99
+ async with self.lock:
100
+ self.worker_status[worker_id][state] += 1
101
+
102
+ async def get_status_table(self) -> str:
103
+ """
104
+ Generate a formatted table of the current status of all workers.
105
+
106
+ Returns:
107
+ str: A string representation of the workers' statuses.
108
+ """
109
+ async with self.lock:
110
+ # Determine all unique states across all workers
111
+ all_states: Set[str] = set()
112
+ for states in self.worker_status.values():
113
+ all_states.update(states.keys())
114
+ sorted_states: List[str] = sorted(all_states)
115
+
116
+ headers = ["Worker ID"] + sorted_states # type: ignore
117
+ rows = []
118
+ for worker_id, states in sorted(self.worker_status.items()):
119
+ row = [str(worker_id)]
120
+ for state in sorted_states:
121
+ count = states.get(state, 0)
122
+ row.append(str(count))
123
+ rows.append(row)
124
+
125
+ # Calculate column widths
126
+ col_widths = [len(header) for header in headers]
127
+ for row in rows:
128
+ for idx, cell in enumerate(row):
129
+ col_widths[idx] = max(col_widths[idx], len(cell))
130
+
131
+ # Create the table header
132
+ header_line = " | ".join(header.ljust(col_widths[idx]) for idx, header in enumerate(headers))
133
+ separator = "-+-".join("-" * col_widths[idx] for idx in range(len(headers)))
134
+
135
+ # Create the table rows
136
+ row_lines = [" | ".join(cell.ljust(col_widths[idx]) for idx, cell in enumerate(row)) for row in rows]
137
+
138
+ # Combine all parts
139
+ table = "\n".join([header_line, separator] + row_lines)
140
+ return table
141
+
142
+ def __str__(self):
143
+ """
144
+ String representation is not directly supported.
145
+ Use 'await get_status_table()' to retrieve the status table.
146
+ """
147
+ raise NotImplementedError("Use 'await get_status_table()' to get the status table.")
ocrflux/pipeline.py ADDED
@@ -0,0 +1,861 @@
1
+ import argparse
2
+ import asyncio
3
+ import atexit
4
+ import base64
5
+ import json
6
+ import logging
7
+ import shutil
8
+ import os
9
+ import copy
10
+ import random
11
+ import re
12
+ import sys
13
+ import time
14
+ from concurrent.futures.process import BrokenProcessPool
15
+ from io import BytesIO
16
+ from urllib.parse import urlparse
17
+
18
+ import httpx
19
+ from huggingface_hub import snapshot_download
20
+ from PIL import Image
21
+ from pypdf import PdfReader
22
+ from tqdm import tqdm
23
+
24
+ from ocrflux.check import (
25
+ check_poppler_version,
26
+ check_vllm_version,
27
+ check_torch_gpu_available,
28
+ )
29
+ from ocrflux.image_utils import get_page_image, is_image
30
+ from ocrflux.table_format import trans_markdown_text
31
+ from ocrflux.metrics import MetricsKeeper, WorkerTracker
32
+ from ocrflux.prompts import PageResponse, build_page_to_markdown_prompt, build_element_merge_detect_prompt, build_html_table_merge_prompt
33
+ from ocrflux.work_queue import LocalWorkQueue, WorkQueue
34
+
35
+ # Initialize logger
36
+ logger = logging.getLogger(__name__)
37
+ logger.setLevel(logging.DEBUG)
38
+ logger.propagate = False
39
+
40
+ vllm_logger = logging.getLogger("vllm")
41
+ vllm_logger.propagate = False
42
+
43
+ file_handler = logging.FileHandler("OCRFlux-debug.log", mode="a")
44
+ file_handler.setLevel(logging.DEBUG)
45
+ file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
46
+
47
+ console_handler = logging.StreamHandler()
48
+ console_handler.setLevel(logging.INFO)
49
+ console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
50
+
51
+ # Add handlers to the logger
52
+ logger.addHandler(file_handler)
53
+ logger.addHandler(console_handler)
54
+ vllm_logger.addHandler(file_handler)
55
+
56
+ # Quiet logs from pypdf
57
+ logging.getLogger("pypdf").setLevel(logging.ERROR)
58
+
59
+ # Global variables for token statistics
60
+ metrics = MetricsKeeper(window=60 * 5)
61
+ tracker = WorkerTracker()
62
+
63
+ def build_page_to_markdown_query(args, pdf_path: str, page_number: int, target_longest_image_dim: int, image_rotation: int = 0) -> dict:
64
+ assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
65
+
66
+ image = get_page_image(pdf_path, page_number, target_longest_image_dim=target_longest_image_dim, image_rotation=image_rotation)
67
+ buffered = BytesIO()
68
+ image.save(buffered, format="PNG")
69
+ image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
70
+
71
+ return {
72
+ "model": args.model,
73
+ "messages": [
74
+ {
75
+ "role": "user",
76
+ "content": [
77
+ {"type": "text", "text": build_page_to_markdown_prompt()},
78
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
79
+ ],
80
+ }
81
+ ],
82
+ "temperature": 0.0,
83
+ }
84
+
85
+ def build_element_merge_detect_query(args,text_list_1,text_list_2) -> dict:
86
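+ # Text-only task: attach a tiny 28x28 black placeholder image to satisfy the multimodal chat request format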
+ image = Image.new('RGB', (28, 28), color='black')
87
+
88
+ buffered = BytesIO()
89
+ image.save(buffered, format="PNG")
90
+
91
+ image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
92
+
93
+ return {
94
+ "model": args.model,
95
+ "messages": [
96
+ {
97
+ "role": "user",
98
+ "content": [
99
+ {"type": "text", "text": build_element_merge_detect_prompt(text_list_1,text_list_2)},
100
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
101
+ ],
102
+ }
103
+ ],
104
+ "temperature": 0.0,
105
+ }
106
+
107
+ def build_html_table_merge_query(args,text_1,text_2) -> dict:
108
+ image = Image.new('RGB', (28, 28), color='black')
109
+
110
+ buffered = BytesIO()
111
+ image.save(buffered, format="PNG")
112
+
113
+ image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
114
+
115
+ return {
116
+ "model": args.model,
117
+ "messages": [
118
+ {
119
+ "role": "user",
120
+ "content": [
121
+ {"type": "text", "text": build_html_table_merge_prompt(text_1,text_2)},
122
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
123
+ ],
124
+ }
125
+ ],
126
+ "temperature": 0.0,
127
+ }
128
+
129
+ # Manual simple implementation of HTTP Post
130
+ # It feels strange perhaps, but httpx and aiohttp are very complex beasts
131
+ # Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
132
+ # that at the scale of 100M+ requests, they deadlock in different strange ways
133
+ async def apost(url, json_data):
134
+ parsed_url = urlparse(url)
135
+ host = parsed_url.hostname
136
+ port = parsed_url.port or 80
137
+ path = parsed_url.path or "/"
138
+
139
+ writer = None
140
+ try:
141
+ reader, writer = await asyncio.open_connection(host, port)
142
+
143
+ json_payload = json.dumps(json_data)
144
+ request = (
145
+ f"POST {path} HTTP/1.1\r\n"
146
+ f"Host: {host}\r\n"
147
+ f"Content-Type: application/json\r\n"
148
+ f"Content-Length: {len(json_payload)}\r\n"
149
+ f"Connection: close\r\n\r\n"
150
+ f"{json_payload}"
151
+ )
152
+ writer.write(request.encode())
153
+ await writer.drain()
154
+
155
+ # Read status line
156
+ status_line = await reader.readline()
157
+ if not status_line:
158
+ raise ConnectionError("No response from server")
159
+ status_parts = status_line.decode().strip().split(" ", 2)
160
+ if len(status_parts) < 2:
161
+ raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
162
+ status_code = int(status_parts[1])
163
+
164
+ # Read headers
165
+ headers = {}
166
+ while True:
167
+ line = await reader.readline()
168
+ if line in (b"\r\n", b"\n", b""):
169
+ break
170
+ key, _, value = line.decode().partition(":")
171
+ headers[key.strip().lower()] = value.strip()
172
+
173
+ # Read response body
174
+ if "content-length" in headers:
175
+ body_length = int(headers["content-length"])
176
+ response_body = await reader.readexactly(body_length)
177
+ else:
178
+ raise ConnectionError("Anything other than fixed content length responses are not implemented yet")
179
+
180
+ return status_code, response_body
181
+ except Exception as e:
182
+ # Pass through errors
183
+ raise e
184
+ finally:
185
+ # But just make sure to close the socket on your way out
186
+ if writer is not None:
187
+ try:
188
+ writer.close()
189
+ await writer.wait_closed()
190
+ except:
191
+ pass
192
+
193
+ async def process_task(args, worker_id, task_name, task_args):
194
+ COMPLETION_URL = f"http://localhost:{args.port}/v1/chat/completions"
195
+ MAX_RETRIES = args.max_page_retries
196
+ TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
197
+ exponential_backoffs = 0
198
+ local_image_rotation = 0
199
+ attempt = 0
200
+ await tracker.track_work(worker_id, f"{worker_id}", "started")
201
+ while attempt < MAX_RETRIES:
202
+ if task_name == 'page_to_markdown':
203
+ pdf_path,page_number = task_args
204
+ query = build_page_to_markdown_query(args, pdf_path, page_number, args.target_longest_image_dim, image_rotation=local_image_rotation)
205
+ elif task_name == 'element_merge_detect':
206
+ text_list_1,text_list_2 = task_args
207
+ query = build_element_merge_detect_query(args, text_list_1, text_list_2)
208
+ elif task_name == 'html_table_merge':
209
+ table_1,table_2 = task_args
210
+ query = build_html_table_merge_query(args, table_1, table_2)
211
+ query["temperature"] = TEMPERATURE_BY_ATTEMPT[
212
+ min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
213
+ ] # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
214
+
215
+ try:
216
+ status_code, response_body = await apost(COMPLETION_URL, json_data=query)
217
+
218
+ if status_code == 400:
219
+ raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
220
+ elif status_code == 500:
221
+ raise ValueError(f"Got InternalServerError from server: {response_body}, skipping this response")
222
+ elif status_code != 200:
223
+ raise ValueError(f"Error http status {status_code}")
224
+
225
+ base_response_data = json.loads(response_body)
226
+
227
+ metrics.add_metrics(
228
+ vllm_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
229
+ vllm_output_tokens=base_response_data["usage"].get("completion_tokens", 0),
230
+ )
231
+
232
+ response_content = base_response_data["choices"][0]["message"]["content"]
233
+ if task_name == 'page_to_markdown':
234
+ model_response_json = json.loads(response_content)
235
+ page_response = PageResponse(**model_response_json)
236
+ if not page_response.is_rotation_valid and attempt < MAX_RETRIES - 1:
237
+ local_image_rotation = page_response.rotation_correction
238
+ raise ValueError("invalid page rotation")
239
+ try:
240
+ return_data = trans_markdown_text(page_response.natural_text,"matrix2html")
241
+ except:
242
+ if attempt < MAX_RETRIES - 1:
243
+ raise
244
+ else:
245
+ return_data = page_response.natural_text.replace("<t>","").replace("<l>","").replace("<lt>","")
246
+
247
+ elif task_name == 'element_merge_detect':
248
+ pattern = r"\((\d+), (\d+)\)"
249
+ matches = re.findall(pattern, response_content)
250
+ return_data = [(int(x), int(y)) for x, y in matches]
251
+ elif task_name == 'html_table_merge':
252
+ if not (response_content.startswith("<table>") and response_content.endswith("</table>")):
253
+ raise ValueError("Response is not a table")
254
+ return_data = response_content
255
+ else:
256
+ raise ValueError(f"Unknown task_name {task_name}")
257
+
258
+ await tracker.track_work(worker_id, f"{worker_id}", "finished")
259
+ return return_data
260
+
261
+ except (ConnectionError, OSError, asyncio.TimeoutError) as e:
262
+ logger.warning(f"Client error on attempt {attempt} for {worker_id}: {type(e)} {e}")
263
+
264
+ # Now we want to do exponential backoff, and not count this as an actual page retry
265
+ # Page retries are supposed to be for fixing bad results from the model, but actual requests to vllm
266
+ # are supposed to work. Probably this means that the server is just restarting
267
+ sleep_delay = 10 * (2**exponential_backoffs)
268
+ exponential_backoffs += 1
269
+ logger.info(f"Sleeping for {sleep_delay} seconds on {worker_id} to allow server restart")
270
+ await asyncio.sleep(sleep_delay)
271
+ except asyncio.CancelledError:
272
+ logger.info(f"Process {worker_id} cancelled")
273
+ await tracker.track_work(worker_id, f"{worker_id}", "cancelled")
274
+ raise
275
+ except json.JSONDecodeError as e:
276
+ logger.warning(f"JSON decode error on attempt {attempt} for {worker_id}: {e}")
277
+ attempt += 1
278
+ except ValueError as e:
279
+ logger.warning(f"ValueError on attempt {attempt} for {worker_id}: {type(e)} - {e}")
280
+ attempt += 1
281
+ except Exception as e:
282
+ logger.exception(f"Unexpected error on attempt {attempt} for {worker_id}: {type(e)} - {e}")
283
+ attempt += 1
284
+
285
+ logger.error(f"Failed to process {worker_id} after {MAX_RETRIES} attempts.")
286
+ await tracker.track_work(worker_id, f"{worker_id}", "errored")
287
+
288
+ return None
289
+
290
+ def postprocess_markdown_text(args, response_text, pdf_path, page_number):
291
+ text_list = response_text.split("\n\n")
292
+ new_text_list = []
293
+ for text in text_list:
294
+ if text.startswith("<Image>") and text.endswith("</Image>"):
295
+ pass
296
+ else:
297
+ new_text_list.append(text)
298
+ return "\n\n".join(new_text_list)
299
+
300
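+ # Assemble the final document text: apply cross-page table merges and element merges (working backward from the last page), then join the remaining non-empty elements in page order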
+ def bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result):
301
+ page_to_markdown_keys = list(page_to_markdown_result.keys())
302
+ element_merge_detect_keys = list(element_merge_detect_result.keys())
303
+ html_table_merge_keys = list(html_table_merge_result.keys())
304
+
305
+ for page_1,page_2,elem_idx_1,elem_idx_2 in sorted(html_table_merge_keys,key=lambda x: -x[0]):
306
+ page_to_markdown_result[page_1][elem_idx_1] = html_table_merge_result[(page_1,page_2,elem_idx_1,elem_idx_2)]
307
+ page_to_markdown_result[page_2][elem_idx_2] = ''
308
+
309
+ for page_1,page_2 in sorted(element_merge_detect_keys,key=lambda x: -x[0]):
310
+ for elem_idx_1,elem_idx_2 in element_merge_detect_result[(page_1,page_2)]:
311
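+ # Join without a space when the first fragment is empty, ends with a hyphen, or ends with a CJK character; otherwise join with a single space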
+ if len(page_to_markdown_result[page_1][elem_idx_1]) == 0 or page_to_markdown_result[page_1][elem_idx_1][-1] == '-' or ('\u4e00' <= page_to_markdown_result[page_1][elem_idx_1][-1] <= '\u9fff'):
312
+ page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + '' + page_to_markdown_result[page_2][elem_idx_2]
313
+ else:
314
+ page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + ' ' + page_to_markdown_result[page_2][elem_idx_2]
315
+ page_to_markdown_result[page_2][elem_idx_2] = ''
316
+
317
+ document_text_list = []
318
+ for page in page_to_markdown_keys:
319
+ page_text_list = [s for s in page_to_markdown_result[page] if s]
320
+ document_text_list += page_text_list
321
+ return "\n\n".join(document_text_list)
322
+
323
+ async def process_pdf(args, worker_id: int, pdf_path: str):
324
+ logger.info(f"Start process_pdf for {pdf_path}")
325
+ if pdf_path.lower().endswith(".pdf"):
326
+ try:
327
+ reader = PdfReader(pdf_path)
328
+ num_pages = reader.get_num_pages()
329
+ except:
330
+ logger.exception(f"Could not count number of pages for {pdf_path}, aborting document")
331
+ return None
332
+ else:
333
+ num_pages = 1
334
+
335
+ logger.info(f"Got {num_pages} pages to do for {pdf_path} in worker {worker_id}")
336
+
337
+ try:
338
+ tasks = []
339
+ results = []
340
+ async with asyncio.TaskGroup() as tg:
341
+ for page_num in range(1, num_pages + 1):
342
+ task = tg.create_task(process_task(args, worker_id, task_name='page_to_markdown', task_args=(pdf_path,page_num)))
343
+ tasks.append(task)
344
+
345
+ results = [task.result() for task in tasks]
346
+
347
+ fallback_pages = []
348
+ page_to_markdown_result = {}
349
+ page_pairs = []
350
+ for i,result in enumerate(results):
351
+ if result != None:
352
+ page_number = i+1
353
+ page_to_markdown_result[i+1] = postprocess_markdown_text(args,result,pdf_path,page_number).split("\n\n")
354
+ if page_number-1 in page_to_markdown_result.keys():
355
+ page_pairs.append((page_number-1,page_number))
356
+ else:
357
+ fallback_pages.append(i)
358
+
359
+ num_fallback_pages = len(fallback_pages)
360
+
361
+ if num_fallback_pages / num_pages > args.max_page_error_rate:
362
+ logger.error(
363
+ f"Document {pdf_path} has {num_fallback_pages} fallback pages out of {num_pages} exceeding max_page_error_rate of {args.max_page_error_rate}, discarding document."
364
+ )
365
+ return None
366
+ elif num_fallback_pages > 0:
367
+ logger.warning(
368
+ f"Document {pdf_path} processed with {num_fallback_pages} fallback pages out of {num_pages}."
369
+ )
370
+
371
+ if args.skip_cross_page_merge:
372
+ page_texts = {}
373
+ document_text_list = []
374
+ sorted_page_keys = sorted(list(page_to_markdown_result.keys()))
375
+ for page_number in sorted_page_keys:
376
+ page_texts[str(page_number-1)] = "\n\n".join(page_to_markdown_result[page_number])
377
+ document_text_list.append(page_texts[str(page_number-1)])
378
+ document_text = "\n\n".join(document_text_list)
379
+ return {
380
+ "orig_path": pdf_path,
381
+ "num_pages": num_pages,
382
+ "document_text": document_text,
383
+ "page_texts": page_texts,
384
+ "fallback_pages": fallback_pages,
385
+ }
386
+
387
+ tasks = []
388
+ results = []
389
+ async with asyncio.TaskGroup() as tg:
390
+ for page_1,page_2 in page_pairs:
391
+ task = tg.create_task(process_task(args, worker_id, task_name='element_merge_detect', task_args=(page_to_markdown_result[page_1], page_to_markdown_result[page_2])))
392
+ tasks.append(task)
393
+ results = [task.result() for task in tasks]
394
+
395
+ element_merge_detect_result = {}
396
+ table_pairs = []
397
+ for page_pair,result in zip(page_pairs,results):
398
+ if result != None:
399
+ page_1,page_2 = page_pair
400
+ element_merge_detect_result[(page_1,page_2)] = result
401
+ for elem_idx_1,elem_idx_2 in result:
402
+ text_1 = page_to_markdown_result[page_1][elem_idx_1]
403
+ text_2 = page_to_markdown_result[page_2][elem_idx_2]
404
+ if text_1.startswith("<table>") and text_1.endswith("</table>") and text_2.startswith("<table>") and text_2.endswith("</table>"):
405
+ table_pairs.append((page_1,page_2,elem_idx_1,elem_idx_2))
406
+
407
+ tmp_page_to_markdown_result = copy.deepcopy(page_to_markdown_result)
408
+ table_pairs = sorted(table_pairs,key=lambda x: -x[0])
409
+ html_table_merge_result = {}
410
+ i = 0
411
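+ # Merge tables in batches; a batch ends before a pair whose second table is already the target of a scheduled merge, so tables spanning 3+ pages are merged across successive batches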
+ while i < len(table_pairs):
412
+ async with asyncio.TaskGroup() as tg:
413
+ tasks = []
414
+ ids_1 = []
415
+ ids_2 = []
416
+ page_1,page_2,elem_idx_1,elem_idx_2 = table_pairs[i]
417
+ task = tg.create_task(process_task(args, worker_id, task_name='html_table_merge', task_args=(tmp_page_to_markdown_result[page_1][elem_idx_1], tmp_page_to_markdown_result[page_2][elem_idx_2])))
418
+ tasks.append(task)
419
+ ids_1.append((page_1,elem_idx_1))
420
+ ids_2.append((page_2,elem_idx_2))
421
+ j = i + 1
422
+ while j < len(table_pairs):
423
+ page_1,page_2,elem_idx_1,elem_idx_2 = table_pairs[j]
424
+ if (page_2, elem_idx_2) not in ids_1:
425
+ task = tg.create_task(process_task(args, worker_id, task_name='html_table_merge', task_args=(tmp_page_to_markdown_result[page_1][elem_idx_1], tmp_page_to_markdown_result[page_2][elem_idx_2])))
426
+ tasks.append(task)
427
+ ids_1.append((page_1,elem_idx_1))
428
+ ids_2.append((page_2,elem_idx_2))
429
+ j = j + 1
430
+ else:
431
+ break
432
+
433
+ results = [task.result() for task in tasks]
434
+
435
+ for k,result in enumerate(results):
436
+ page_1,elem_idx_1 = ids_1[k]
437
+ page_2,elem_idx_2 = ids_2[k]
438
+ if result != None:
439
+ html_table_merge_result[(page_1,page_2,elem_idx_1,elem_idx_2)] = result
440
+ tmp_page_to_markdown_result[page_1][elem_idx_1] = html_table_merge_result[(page_1,page_2,elem_idx_1,elem_idx_2)]
441
+ i = j
442
+
443
+ page_texts = {}
444
+ for page_number in page_to_markdown_result.keys():
445
+ page_texts[str(page_number-1)] = "\n\n".join(page_to_markdown_result[page_number])
446
+
447
+ document_text = bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result)
448
+
449
+ return {
450
+ "orig_path": pdf_path,
451
+ "num_pages": num_pages,
452
+ "document_text": document_text,
453
+ "page_texts": page_texts,
454
+ "fallback_pages": fallback_pages,
455
+ }
456
+ except Exception as e:
457
+ # Check for ExceptionGroup with BrokenProcessPool
458
+ if isinstance(e, ExceptionGroup):
459
+ broken_pool, other = e.split(BrokenProcessPool)
460
+ if broken_pool is not None: # Found at least one BrokenProcessPool
461
+ logger.critical("Encountered BrokenProcessPool, exiting process.")
462
+ sys.exit(1)
463
+
464
+ logger.exception(f"Exception in process_pdf for {pdf_path}: {e}")
465
+ return None
466
+
467
+ async def process_json(args, worker_id: int, json_path: str):
468
+ try:
469
+ json_data = json.load(open(json_path,'r'))
470
+ except:
471
+ logger.exception(f"Could not load {json_path}, aborting document")
+ return None
472
+ try:
473
+ if args.task == 'merge_pages':
474
+ page_1 = json_data['page_1'].split("\n\n")
475
+ page_2 = json_data['page_2'].split("\n\n")
476
+ async with asyncio.TaskGroup() as tg:
477
+ task = tg.create_task(process_task(args, worker_id, task_name='element_merge_detect', task_args=(page_1, page_2)))
478
+ result = task.result()
479
+ return {
480
+ "orig_path": json_path,
481
+ "merge_pairs": result
482
+ }
483
+ elif args.task == 'merge_tables':
484
+ table_1 = json_data['table_1']
485
+ table_2 = json_data['table_2']
486
+ async with asyncio.TaskGroup() as tg:
487
+ task = tg.create_task(process_task(args, worker_id, task_name='html_table_merge', task_args=(table_1, table_2)))
488
+ result = task.result()
489
+ return {
490
+ "orig_path": json_path,
491
+ "merged_tables": result
492
+ }
493
+ else:
494
+ raise ValueError(f"Unknown task {args.task}")
495
+
496
+ except Exception as e:
497
+ # Check for ExceptionGroup with BrokenProcessPool
498
+ if isinstance(e, ExceptionGroup):
499
+ broken_pool, other = e.split(BrokenProcessPool)
500
+ if broken_pool is not None: # Found at least one BrokenProcessPool
501
+ logger.critical("Encountered BrokenProcessPool, exiting process.")
502
+ sys.exit(1)
503
+
504
+ logger.exception(f"Exception in process_json for {json_path}: {e}")
505
+ return None
506
+
507
+ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
508
+ while True:
509
+ # Wait until allowed to proceed
510
+ await semaphore.acquire()
511
+
512
+ work_item = await work_queue.get_work()
513
+
514
+ if work_item is None:
515
+ logger.info(f"Worker {worker_id} exiting due to empty queue")
516
+ semaphore.release()
517
+ break
518
+
519
+ logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
520
+ await tracker.clear_work(worker_id)
521
+
522
+ try:
523
+ async with asyncio.TaskGroup() as tg:
524
+ if args.task == 'pdf2markdown':
525
+ tasks = [tg.create_task(process_pdf(args, worker_id, pdf_path)) for pdf_path in work_item.work_paths]
526
+ elif args.task == 'merge_pages' or args.task == 'merge_tables':
527
+ tasks = [tg.create_task(process_json(args, worker_id, json_path)) for json_path in work_item.work_paths]
528
+ else:
529
+ raise ValueError(f"Unknown task {args.task}")
530
+
531
+ logger.info(f"Created all tasks for {work_item.hash}")
532
+
533
+ logger.info(f"Finished TaskGroup for worker on {work_item.hash}")
534
+
535
+ results = []
536
+ for task in tasks:
537
+ try:
538
+ result = task.result()
539
+ except:
540
+ result = None  # a failed task must not reuse the previous task's result
541
+
542
+ if result is not None:
543
+ results.append(result)
544
+
545
+ logger.info(f"Got {len(results)} docs for {work_item.hash}")
546
+
547
+ output_final_path = os.path.join(args.workspace, "results", f"output_{work_item.hash}.jsonl")
548
+ with open(output_final_path, "w") as f:
549
+ for result in results:
550
+ f.write(json.dumps(result))
551
+ f.write("\n")
552
+
553
+ await work_queue.mark_done(work_item)
554
+ except Exception as e:
555
+ logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
556
+ finally:
557
+ semaphore.release()
558
+
559
+ async def vllm_server_task(args, semaphore):
560
+ model_name_or_path = args.model
561
+
562
+ cmd = [
563
+ "vllm",
564
+ "serve",
565
+ model_name_or_path,
566
+ "--port",
567
+ str(args.port),
568
+ "--max-model-len",
569
+ str(args.model_max_context),
570
+ "--gpu_memory_utilization",
571
+ str(0.8)
572
+ ]
573
+
574
+ proc = await asyncio.create_subprocess_exec(
575
+ *cmd,
576
+ stdout=asyncio.subprocess.PIPE,
577
+ stderr=asyncio.subprocess.PIPE,
578
+ )
579
+
580
+ # Ensure the subprocess is terminated on exit
581
+ def _kill_proc():
582
+ proc.terminate()
583
+
584
+ atexit.register(_kill_proc)
585
+
586
+ # Shared variables between tasks
587
+ last_running_req, last_queue_req = 0, 0
588
+ server_printed_ready_message = False
589
+ last_semaphore_release = time.time()
590
+
591
+ async def process_line(line):
592
+ nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
593
+ vllm_logger.info(line)
594
+
595
+ # if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
596
+ # can see any warnings/errors more easily
597
+ if not server_printed_ready_message:
598
+ logger.info(line)
599
+
600
+ if "Detected errors during sampling" in line:
601
+ logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
602
+ sys.exit(1)
603
+
604
+ # TODO, need to trace down this issue in vllm itself, but it will otherwise cause the server to lock up
605
+ if "IndexError: list index out of range" in line:
606
+ logger.error("IndexError in model, restarting server")
607
+ proc.terminate()
608
+
609
+ if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
610
+ server_printed_ready_message = True
611
+ last_semaphore_release = time.time()
612
+
613
+ match = re.search(r"Running: (\d+)", line)
614
+ if match:
615
+ last_running_req = int(match.group(1))
616
+
617
+ match = re.search(r"(?:Waiting|Pending):\s*(\d+)", line)
618
+ if match:
619
+ last_queue_req = int(match.group(1))
620
+ logger.info(f"vllm running req: {last_running_req} queue req: {last_queue_req}")
621
+
622
+ async def read_stream(stream):
623
+ while True:
624
+ line = await stream.readline()
625
+ if not line:
626
+ break
627
+ try:
628
+ line = line.decode("utf-8").rstrip()
629
+ await process_line(line)
630
+ except Exception as ex:
631
+ logger.warning(f"Got {ex} when reading log line from inference server, skipping")
632
+
633
+ async def timeout_task():
634
+ nonlocal last_running_req, last_queue_req, last_semaphore_release
635
+ try:
636
+ while True:
637
+ await asyncio.sleep(1)
638
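+ # Release the semaphore when the server is ready, its request queue is empty, and at least 30 seconds have passed since the last release, letting another worker start submitting work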
+ if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
639
+ semaphore.release()
640
+ last_semaphore_release = time.time()
641
+ logger.info("Semaphore released, allowing a worker to proceed.")
642
+ except asyncio.CancelledError:
643
+ pass # Clean up if the task is cancelled
644
+
645
+ # Start tasks to read stdout, stderr, and handle timeout logic
646
+ stdout_task = asyncio.create_task(read_stream(proc.stdout))
647
+ stderr_task = asyncio.create_task(read_stream(proc.stderr))
648
+ timeout_task = asyncio.create_task(timeout_task())
649
+
650
+ try:
651
+ await proc.wait()
652
+ except asyncio.CancelledError:
653
+ logger.info("Got cancellation request for VLLM server")
654
+ proc.terminate()
655
+ raise
656
+
657
+ timeout_task.cancel()
658
+ await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
659
+
660
+ async def vllm_server_host(args, semaphore):
661
+ MAX_RETRIES = 5
662
+ retry = 0
663
+
664
+ while retry < MAX_RETRIES:
665
+ await vllm_server_task(args, semaphore)
666
+ logger.warning("VLLM server task ended")
667
+ retry += 1
668
+
669
+ if retry >= MAX_RETRIES:
670
+ logger.error(f"Ended up starting the vllm server more than {retry} times, cancelling pipeline")
671
+ logger.error("")
672
+ logger.error("Please make sure vllm is installed according to the latest instructions here: https://docs.vllm.ai/start/install.html")
673
+ sys.exit(1)
674
+
675
+ async def vllm_server_ready(args):
676
+ max_attempts = 300
677
+ delay_sec = 1
678
+ url = f"http://localhost:{args.port}/v1/models"
679
+
680
+ for attempt in range(1, max_attempts + 1):
681
+ try:
682
+ async with httpx.AsyncClient() as session:
683
+ response = await session.get(url)
684
+
685
+ if response.status_code == 200:
686
+ logger.info("vllm server is ready.")
687
+ return
688
+ else:
689
+ logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
690
+ except Exception:
691
+ logger.warning(f"Attempt {attempt}: Please wait for vllm server to become ready...")
692
+
693
+ await asyncio.sleep(delay_sec)
694
+
695
+ raise Exception("vllm server did not become ready after waiting.")
696
+
697
+ async def download_model(model_name_or_path: str):
698
+ if os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
699
+ logger.info(f"Using local model path at '{model_name_or_path}'")
700
+ else:
701
+ logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
702
+ snapshot_download(repo_id=model_name_or_path)
703
+
704
+ async def metrics_reporter(work_queue):
705
+ while True:
706
+ # Leading newlines preserve table formatting in logs
707
+ logger.info(f"Queue remaining: {work_queue.size}")
708
+ logger.info("\n" + str(metrics))
709
+ logger.info("\n" + str(await tracker.get_status_table()))
710
+ await asyncio.sleep(10)
711
+
712
+ async def main():
713
+ parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline")
714
+ parser.add_argument(
715
+ "workspace",
716
+ help="The filesystem path where work will be stored, can be a local folder",
717
+ )
718
+
719
+ parser.add_argument("--task", type=str, choices=['pdf2markdown','merge_pages','merge_tables'], default='pdf2markdown', help="task names, could be 'pdf2markdown', 'merge_pages' or 'merge_tables'")
720
+
721
+ parser.add_argument(
722
+ "--data",
723
+ nargs="*",
724
+ help="List of paths to files to process",
725
+ default=None,
726
+ )
727
+
728
+ parser.add_argument("--pages_per_group", type=int, default=500, help="Aiming for this many pdf pages per work item group")
729
+ parser.add_argument("--max_page_retries", type=int, default=8, help="Max number of times we will retry rendering a page")
730
+ parser.add_argument("--max_page_error_rate", type=float, default=0.004, help="Rate of allowable failed pages in a document, 1/250 by default")
731
+ parser.add_argument("--workers", type=int, default=8, help="Number of workers to run at a time")
732
+
733
+ # Model parameters
734
+ parser.add_argument(
735
+ "--model",
736
+ help="The path to the model",
737
+ default="ChatDOC/OCRFlux-3B",
738
+ )
739
+ parser.add_argument("--model_max_context", type=int, default=16384, help="Maximum context length that the model was fine tuned under")
740
+ parser.add_argument("--model_chat_template", type=str, default="qwen2-vl", help="Chat template to pass to vllm server")
741
+ parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1024)
742
+
743
+ parser.add_argument("--skip_cross_page_merge", action="store_true", help="Whether to skip cross-page merging")
744
+
745
+ parser.add_argument("--port", type=int, default=40078, help="Port to use for the VLLM server")
746
+ args = parser.parse_args()
747
+
748
+ if os.path.exists(args.workspace):
749
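+ # Start from a clean workspace: any queue state and results from previous runs are removed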
+ shutil.rmtree(args.workspace)
750
+
751
+ # We need poppler to load the initial pdfs, even if we are not processing them here
752
+ check_poppler_version()
753
+
754
+ work_queue = LocalWorkQueue(args.workspace)
755
+
756
+ if args.task == 'pdf2markdown':
757
+ pdf_work_paths = set()
758
+
759
+ for pdf_path in args.data:
760
+ if os.path.exists(pdf_path):
761
+ if pdf_path.lower().endswith(".pdf") and open(pdf_path, "rb").read(4) == b"%PDF":
762
+ logger.info(f"Loading file at {pdf_path} as PDF document")
763
+ pdf_work_paths.add(pdf_path)
764
+ elif is_image(pdf_path):
765
+ logger.info(f"Loading file at {pdf_path} as image document")
766
+ pdf_work_paths.add(pdf_path)
767
+ else:
768
+ raise ValueError(f"Unsupported file extension for {pdf_path}")
769
+ else:
770
+ raise ValueError(f"{pdf_path} does not exist")
771
+
772
+ logger.info(f"Found {len(pdf_work_paths):,} total pdf paths to add")
773
+
774
+ # Estimate average pages per pdf
775
+ sample_size = min(100, len(pdf_work_paths))
776
+ sampled_pdfs = random.sample(list(pdf_work_paths), sample_size)
777
+ page_counts = []
778
+
779
+ for pdf_path in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
780
+ try:
781
+ if pdf_path.lower().endswith(".pdf"):
782
+ reader = PdfReader(pdf_path)
783
+ page_counts.append(len(reader.pages))
784
+ else:
785
+ page_counts.append(1)
786
+ except Exception as e:
787
+ logger.warning(f"Failed to read {pdf_path}: {e}")
788
+
789
+ if page_counts:
790
+ avg_pages_per_pdf = sum(page_counts) / len(page_counts)
791
+ else:
792
+ logger.warning("Could not read any PDFs to estimate average page count.")
793
+ avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails
794
+
795
+ items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
796
+ logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")
797
+
798
+ # Now call populate_queue
799
+ await work_queue.populate_queue(pdf_work_paths, items_per_group)
800
+ elif args.task == 'merge_pages' or args.task == 'merge_tables':
801
+ json_work_paths = set()
802
+ for json_path in args.data:
803
+ if os.path.exists(json_path):
804
+ if json_path.lower().endswith(".json"):
805
+ json_work_paths.add(json_path)
806
+ elif json_path.lower().endswith(".txt"):
807
+ logger.info(f"Loading file at {json_path} as list of paths")
808
+ with open(json_path, "r") as f:
809
+ json_work_paths |= set(filter(None, (line.strip() for line in f)))
810
+ else:
811
+ raise ValueError(f"Unsupported file extension for {json_path}")
812
+ else:
813
+ raise ValueError(f"{json_path} does not exist")
814
+
815
+ # Now call populate_queue
816
+ await work_queue.populate_queue(json_work_paths, args.pages_per_group)
817
+
818
+
819
+ # If you get this far, then you are doing inference and need a GPU
820
+ check_vllm_version()
821
+ check_torch_gpu_available()
822
+
823
+ logger.info(f"Starting pipeline with PID {os.getpid()}")
824
+
825
+ # Download the model before you do anything else
826
+ await download_model(args.model)
827
+
828
+ # Initialize the work queue
829
+ qsize = await work_queue.initialize_queue()
830
+
831
+ if qsize == 0:
832
+ logger.info("No work to do, exiting")
833
+ return
834
+ # Create a semaphore to control worker access
835
+ # We only allow one worker to move forward with requests, until the server has no more requests in its queue
836
+ # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
837
+ # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
838
+ semaphore = asyncio.Semaphore(1)
839
+
840
+ vllm_server = asyncio.create_task(vllm_server_host(args, semaphore))
841
+
842
+ await vllm_server_ready(args)
843
+
844
+ metrics_task = asyncio.create_task(metrics_reporter(work_queue))
845
+
846
+ # Create worker tasks to process the queue concurrently.
847
+ worker_tasks = []
848
+ for i in range(args.workers):
849
+ task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
850
+ worker_tasks.append(task)
851
+
852
+ # Wait for all worker tasks to finish
853
+ await asyncio.gather(*worker_tasks)
854
+
855
+ vllm_server.cancel()
856
+ metrics_task.cancel()
857
+ logger.info("Work done")
858
+
859
+
860
+ if __name__ == "__main__":
861
+ asyncio.run(main())
ocrflux/prompts.py ADDED
@@ -0,0 +1,60 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
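+ # Structured result of the page_to_markdown task, parsed from the model's JSON response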
+ @dataclass(frozen=True)
6
+ class PageResponse:
7
+ primary_language: Optional[str]
8
+ is_rotation_valid: bool
9
+ rotation_correction: int
10
+ is_table: bool
11
+ is_diagram: bool
12
+ natural_text: Optional[str]
13
+
14
+ def __post_init__(self):
15
+ # Validate rotation_correction is one of the allowed values
16
+ if self.rotation_correction not in {0, 90, 180, 270}:
17
+ raise ValueError("rotation_correction must be one of [0, 90, 180, 270].")
18
+
19
+ # Type checks
20
+ if not isinstance(self.primary_language, (str, type(None))):
21
+ raise TypeError("primary_language must be of type Optional[str].")
22
+ if not isinstance(self.is_rotation_valid, bool):
23
+ raise TypeError("is_rotation_valid must be of type bool.")
24
+ if not isinstance(self.rotation_correction, int):
25
+ raise TypeError("rotation_correction must be of type int.")
26
+ if not isinstance(self.is_table, bool):
27
+ raise TypeError("is_table must be of type bool.")
28
+ if not isinstance(self.is_diagram, bool):
29
+ raise TypeError("is_diagram must be of type bool.")
30
+ if not isinstance(self.natural_text, (str, type(None))):
31
+ raise TypeError("natural_text must be of type Optional[str].")
32
+
33
+ def build_element_merge_detect_prompt(text_list_1,text_list_2) -> str:
34
+ task = '''Below are two consecutive pages in Markdown format, where each element of them is numbered. Identify pairs of elements which should be merged across the two pages, such as text paragraphs or tables that span across the two pages. Return pairs as [(element_index_of_page1, element_index_of_page2), ...] or [] if no elements should be merged.\n'''
35
+ task += "Previous page:\n"
36
+ for i,text in enumerate(text_list_1):
37
+ task += f"{i}. {text}\n\n"
38
+ task += "Next page:\n"
39
+ for i,text in enumerate(text_list_2):
40
+ task += f"{i}. {text}\n\n"
41
+ return task
42
+
43
+ def build_html_table_merge_prompt(table1,table2) -> str:
44
+ return (
45
+ f"Below are two tables in HTML format, merge them into one table in HTML format.\n"
46
+ f"TABLE 1:\n"
47
+ f"{table1}\n"
48
+ f"TABLE 2:\n"
49
+ f"{table2}\n"
50
+ )
51
+
52
+ def build_page_to_markdown_prompt() -> str:
53
+ return (
54
+ f"Below is the image of one page of a document. "
55
+ f"Just return the plain text representation of this document as if you were reading it naturally.\n"
56
+ f"ALL tables should be presented in HTML format.\n"
57
+ f"If there are images or figures in the page, present them as \"<Image>(left,top),(right,bottom)</Image>\", (left,top,right,bottom) are the coordinates of the top-left and bottom-right corners of the image or figure.\n"
58
+ f"Present all titles and headings as H1 headings.\n"
59
+ f"Do not hallucinate.\n"
60
+ )
ocrflux/table_format.py ADDED
@@ -0,0 +1,143 @@
1
+
2
+ from bs4 import BeautifulSoup
3
+ import re
4
+
5
+ def is_html_table(text):
6
+ soup = BeautifulSoup(text, "html.parser")
7
+ return soup.find('table') is not None
8
+
9
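+ # Convert the model's 'matrix' table format (one <td> per grid cell, with <l>/<t>/<lt> marking cells covered from the left/top/both) into standard HTML with rowspan/colspan attributes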
+ def table_matrix2html(matrix_table):
10
+ soup = BeautifulSoup(matrix_table, 'html.parser')
11
+ table = soup.find('table')
12
+ rownum = 0
13
+ colnum = 0
14
+ cell_dict = {}
15
+ rid = 0
16
+ for tr in table.find_all('tr'):
17
+ cid = 0
18
+ for td in tr.find_all('td'):
19
+ if td.find('l'):
20
+ cell_dict[(rid, cid)] = '<l>'
21
+ elif td.find('t'):
22
+ cell_dict[(rid, cid)] = '<t>'
23
+ elif td.find('lt'):
24
+ cell_dict[(rid, cid)] = '<lt>'
25
+ else:
26
+ text = td.get_text(strip=True)
27
+ cell_dict[(rid, cid)] = text
28
+ cid += 1
29
+ if colnum == 0:
30
+ colnum = cid
31
+ elif cid != colnum:
32
+ raise Exception('colnum not match')
33
+ rid += 1
34
+ rownum = rid
35
+ html_table = ['<table>']
36
+ for rid in range(rownum):
37
+ html_table.append('<tr>')
38
+ for cid in range(colnum):
39
+ if (rid, cid) not in cell_dict.keys():
40
+ continue
41
+ text = cell_dict[(rid, cid)]
42
+ if text == '<l>' or text == '<t>' or text == '<lt>':
43
+ raise Exception('cell not match')
44
+ rowspan = 1
45
+ colspan = 1
46
+ for r in range(rid+1, rownum):
47
+ if (r, cid) in cell_dict.keys() and cell_dict[(r, cid)] == '<t>':
48
+ rowspan += 1
49
+ del cell_dict[(r, cid)]
50
+ else:
51
+ break
52
+ for c in range(cid+1, colnum):
53
+ if (rid, c) in cell_dict.keys() and cell_dict[(rid, c)] == '<l>':
54
+ colspan += 1
55
+ del cell_dict[(rid, c)]
56
+ else:
57
+ break
58
+ for r in range(rid+1, rid+rowspan):
59
+ for c in range(cid+1, cid+colspan):
60
+ if cell_dict[(r, c)] != '<lt>':
61
+ raise Exception('cell not match')
62
+ del cell_dict[(r, c)]
63
+ attr = ''
64
+ if rowspan > 1:
65
+ attr += ' rowspan="{}"'.format(rowspan)
66
+ if colspan > 1:
67
+ attr += ' colspan="{}"'.format(colspan)
68
+ html_table.append("<td{}>{}</td>".format(attr, text))
69
+ html_table.append('</tr>')
70
+ html_table.append('</table>')
71
+ return "".join(html_table)
72
+
73
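+ # Inverse of table_matrix2html: expand rowspan/colspan cells into a full grid, marking covered cells with <l>, <t> or <lt>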
+ def table_html2matrix(html_table):
74
+ soup = BeautifulSoup(html_table, 'html.parser')
75
+ table = soup.find('table')
76
+ rownum = len(table.find_all('tr'))
77
+ colnum = 0
78
+ tr = table.find_all('tr')[0]
79
+ for td in tr.find_all('td'):
80
+ colnum += int(td.get('colspan', 1))  # attribute values come back from BeautifulSoup as strings
81
+ matrix = [[None for _ in range(colnum)] for _ in range(rownum)]
82
+
83
+ rid = 0
84
+ for tr in table.find_all('tr'):
85
+ cid = 0
86
+ for td in tr.find_all('td'):
87
+ for c in range(cid, colnum):
88
+ if matrix[rid][c] is None:
89
+ break
90
+ cid = c
91
+ rowspan = int(td.get('rowspan', 1))
92
+ colspan = int(td.get('colspan', 1))
93
+ cell_text = td.get_text(strip=True)
94
+ for r in range(rid,rid+rowspan):
95
+ if r >= rownum:
96
+ raise Exception('rownum not match')
97
+ for c in range(cid,cid+colspan):
98
+ if c >= colnum:
99
+ raise Exception('colnum not match')
100
+ if matrix[r][c] is not None:
101
+ raise Exception('cell not match')
102
+ if r == rid and c == cid:
103
+ matrix[r][c] = cell_text
104
+ elif r == rid:
105
+ matrix[r][c] = '<l>'
106
+ elif c == cid:
107
+ matrix[r][c] = '<t>'
108
+ else:
109
+ matrix[r][c] = '<lt>'
110
+ cid += colspan
111
+ rid += 1
112
+
113
+ matrix_table = ['<table>']
114
+ for rid in range(rownum):
115
+ matrix_table.append('<tr>')
116
+ for cid in range(colnum):
117
+ matrix_table.append('<td>')
118
+ cell_text = matrix[rid][cid]
119
+ matrix_table.append(cell_text)
120
+ matrix_table.append('</td>')
121
+ matrix_table.append('</tr>')
122
+ matrix_table.append('</table>')
123
+ return "".join(matrix_table)
124
+
125
+ trans_func = {
126
+ "html2matrix": table_html2matrix,
127
+ "matrix2html": table_matrix2html,
128
+ }
129
+
130
+ def trans_markdown_text(markdown_text,trans_type):
131
+ if markdown_text == None:
132
+ return None
133
+ text_list = markdown_text.split('\n\n')
134
+ for i,text in enumerate(text_list):
135
+ if is_html_table(text):
136
+ text_list[i] = trans_func[trans_type](text)
137
+ return "\n\n".join(text_list)
138
+
139
+
140
+
141
+
142
+
143
+
ocrflux/work_queue.py ADDED
@@ -0,0 +1,357 @@
1
+ import abc
2
+ import asyncio
3
+ import datetime
4
+ import hashlib
5
+ import logging
6
+ import os
7
+ import random
8
+ from asyncio import Queue
9
+ from dataclasses import dataclass
10
+ from typing import Any, List, Optional
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class WorkItem:
17
+ """Represents a single work item in the queue"""
18
+
19
+ hash: str
20
+ work_paths: List[str]
21
+
22
+
23
+ class WorkQueue(abc.ABC):
24
+ """
25
+ Base class defining the interface for a work queue.
26
+ """
27
+
28
+ @abc.abstractmethod
29
+ async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
30
+ """
31
+ Add new items to the work queue. The specifics will vary depending on
32
+ whether this is a local or S3-backed queue.
33
+
34
+ Args:
35
+ work_paths: Each individual path that we will process over
36
+ items_per_group: Number of items to group together in a single work item
37
+ """
38
+ pass
39
+
40
+ @abc.abstractmethod
41
+ async def initialize_queue(self) -> int:
42
+ """
43
+ Load the work queue from the relevant store (local or remote)
44
+ and initialize it for processing.
45
+
46
+ For example, this might remove already completed work items and randomize
47
+ the order before adding them to an internal queue.
48
+ """
49
+ pass
50
+
51
+ @abc.abstractmethod
52
+ async def is_completed(self, work_hash: str) -> bool:
53
+ """
54
+ Check if a work item has been completed.
55
+
56
+ Args:
57
+ work_hash: Hash of the work item to check
58
+
59
+ Returns:
60
+ True if the work is completed, False otherwise
61
+ """
62
+ pass
63
+
64
+ @abc.abstractmethod
65
+ async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
66
+ """
67
+ Get the next available work item that isn't completed or locked.
68
+
69
+ Args:
70
+ worker_lock_timeout_secs: Number of seconds before considering
71
+ a worker lock stale (default 30 mins)
72
+
73
+ Returns:
74
+ WorkItem if work is available, None if queue is empty
75
+ """
76
+ pass
77
+
78
+ @abc.abstractmethod
79
+ async def mark_done(self, work_item: WorkItem) -> None:
80
+ """
81
+ Mark a work item as done by removing its lock file
82
+ or performing any other cleanup.
83
+
84
+ Args:
85
+ work_item: The WorkItem to mark as done
86
+ """
87
+ pass
88
+
89
+ @property
90
+ @abc.abstractmethod
91
+ def size(self) -> int:
92
+ """Get current size of work queue"""
93
+ pass
94
+
95
+ @staticmethod
96
+ def _compute_workgroup_hash(work_paths: List[str]) -> str:
97
+ """
98
+ Compute a deterministic hash for a group of paths.
99
+
100
+ Args:
101
+ work_paths: List of paths (local or S3)
102
+
103
+ Returns:
104
+ SHA1 hash of the sorted paths
105
+ """
106
+ sha1 = hashlib.sha1()
107
+ for path in sorted(work_paths):
108
+ sha1.update(path.encode("utf-8"))
109
+ return sha1.hexdigest()
110
+
111
+
112
+ # --------------------------------------------------------------------------------------
113
+ # Local Helpers for reading/writing the index CSV (compressed with zstd) to disk
114
+ # --------------------------------------------------------------------------------------
115
+
116
+ try:
117
+ import zstandard
118
+ except ImportError:
119
+ zstandard = None
120
+
121
+
122
+ def download_zstd_csv_local(local_path: str) -> List[str]:
123
+ """
124
+ Download a zstd-compressed CSV from a local path.
125
+ If the file doesn't exist, returns an empty list.
126
+ """
127
+ if not os.path.exists(local_path):
128
+ return []
129
+
130
+ if not zstandard:
131
+ raise RuntimeError("zstandard package is required for local zstd CSV operations.")
132
+
133
+ with open(local_path, "rb") as f:
134
+ dctx = zstandard.ZstdDecompressor()
135
+ data = dctx.decompress(f.read())
136
+ lines = data.decode("utf-8").splitlines()
137
+ return lines
138
+
139
+
140
+ def upload_zstd_csv_local(local_path: str, lines: List[str]) -> None:
141
+ """
142
+ Upload a zstd-compressed CSV to a local path.
143
+ """
144
+ if not zstandard:
145
+ raise RuntimeError("zstandard package is required for local zstd CSV operations.")
146
+
147
+ data = "\n".join(lines).encode("utf-8")
148
+ cctx = zstandard.ZstdCompressor()
149
+ compressed_data = cctx.compress(data)
150
+
151
+ # Ensure parent directories exist
152
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
153
+
154
+ with open(local_path, "wb") as f:
155
+ f.write(compressed_data)
156
+
157
+
+ # --------------------------------------------------------------------------------------
+ # LocalWorkQueue Implementation
+ # --------------------------------------------------------------------------------------
+
+
+ class LocalWorkQueue(WorkQueue):
+     """
+     A local in-memory and on-disk WorkQueue implementation, which uses
+     a local workspace directory to store the queue index, lock files,
+     and completed results for persistent resumption across process restarts.
+     """
+
+     def __init__(self, workspace_path: str):
+         """
+         Initialize the local work queue.
+
+         Args:
+             workspace_path: Local directory path where the queue index,
+                 results, and locks are stored.
+         """
+         self.workspace_path = os.path.abspath(workspace_path)
+         os.makedirs(self.workspace_path, exist_ok=True)
+
+         # Local index file (compressed)
+         self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
+
+         # Output directory for completed tasks
+         self._results_dir = os.path.join(self.workspace_path, "results")
+         os.makedirs(self._results_dir, exist_ok=True)
+
+         # Directory for lock files
+         self._locks_dir = os.path.join(self.workspace_path, "worker_locks")
+         os.makedirs(self._locks_dir, exist_ok=True)
+
+         # Internal queue
+         self._queue: Queue[Any] = Queue()
+
+     async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
+         """
+         Add new items to the work queue (local version).
+
+         Args:
+             work_paths: Each individual (local, in this context) path
+                 that we will process
+             items_per_group: Number of items to group together in a single work item
+         """
+         # Treat them as local paths, but keep variable name for consistency
+         all_paths = set(work_paths)
+         logger.info(f"Found {len(all_paths):,} total paths")
+
+         # Load existing work groups from local index
+         existing_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
+         existing_groups = {}
+         for line in existing_lines:
+             if line.strip():
+                 parts = line.strip().split(",")
+                 group_hash = parts[0]
+                 group_paths = parts[1:]
+                 existing_groups[group_hash] = group_paths
+
+         existing_path_set = {p for paths in existing_groups.values() for p in paths}
+         new_paths = all_paths - existing_path_set
+         logger.info(f"{len(new_paths):,} new paths to add to the workspace")
+
+         if not new_paths:
+             return
+
+         # Create new work groups
+         new_groups = []
+         current_group = []
+         for path in sorted(new_paths):
+             current_group.append(path)
+             if len(current_group) == items_per_group:
+                 group_hash = self._compute_workgroup_hash(current_group)
+                 new_groups.append((group_hash, current_group))
+                 current_group = []
+         if current_group:
+             group_hash = self._compute_workgroup_hash(current_group)
+             new_groups.append((group_hash, current_group))
+
+         logger.info(f"Created {len(new_groups):,} new work groups")
+
+         # Combine and save updated work groups
+         combined_groups = existing_groups.copy()
+         for group_hash, group_paths in new_groups:
+             combined_groups[group_hash] = group_paths
+
+         combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
+
+         if new_groups:
+             # Write the combined data back to disk in zstd CSV format
+             await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines)
+
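For illustration only, the grouping step in `populate_queue` reduces to the following standalone sketch. `_toy_group_hash` is a stand-in; the real `_compute_workgroup_hash` is defined on the `WorkQueue` base class and is not shown in this file:

```python
import hashlib
from typing import List, Tuple

def _toy_group_hash(paths: List[str]) -> str:
    # Stand-in only: hash the sorted, comma-joined paths
    return hashlib.sha1(",".join(sorted(paths)).encode("utf-8")).hexdigest()

def group_paths(paths: List[str], items_per_group: int) -> List[Tuple[str, List[str]]]:
    groups, current = [], []
    for path in sorted(paths):
        current.append(path)
        if len(current) == items_per_group:
            groups.append((_toy_group_hash(current), current))
            current = []
    if current:  # final, possibly smaller, group
        groups.append((_toy_group_hash(current), current))
    return groups

# group_paths(["a.pdf", "b.pdf", "c.pdf"], 2)
# -> [(<hash of "a.pdf,b.pdf">, ["a.pdf", "b.pdf"]), (<hash of "c.pdf">, ["c.pdf"])]
```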
+     async def initialize_queue(self) -> int:
+         """
+         Load the work queue from the local index file and initialize it for processing.
+         Removes already completed work items and randomizes the order.
+         """
+         # 1) Read the index
+         work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
+         work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
+
+         # 2) Determine which items are completed by scanning local results/*.jsonl
+         if not os.path.isdir(self._results_dir):
+             os.makedirs(self._results_dir, exist_ok=True)
+         done_work_items = [f for f in os.listdir(self._results_dir) if f.startswith("output_") and f.endswith(".jsonl")]
+         done_work_hashes = {fn[len("output_") : -len(".jsonl")] for fn in done_work_items}
+
+         # 3) Filter out completed items
+         remaining_work_hashes = set(work_queue) - done_work_hashes
+         remaining_items = [WorkItem(hash=hash_, work_paths=work_queue[hash_]) for hash_ in remaining_work_hashes]
+         random.shuffle(remaining_items)
+
+         # 4) Initialize our in-memory queue
+         self._queue = asyncio.Queue()
+         for item in remaining_items:
+             await self._queue.put(item)
+
+         logger.info(f"Initialized local queue with {self._queue.qsize()} work items")
+
+         return self._queue.qsize()
+
+     async def is_completed(self, work_hash: str) -> bool:
+         """
+         Check if a work item has been completed locally by seeing if
+         output_{work_hash}.jsonl is present in the results directory.
+
+         Args:
+             work_hash: Hash of the work item to check
+         """
+         output_file = os.path.join(self._results_dir, f"output_{work_hash}.jsonl")
+         return os.path.exists(output_file)
+
+     async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
+         """
+         Get the next available work item that isn't completed or locked.
+
+         Args:
+             worker_lock_timeout_secs: Number of seconds before considering
+                 a worker lock stale (default 30 mins)
+
+         Returns:
+             WorkItem if work is available, None if queue is empty
+         """
+         while True:
+             try:
+                 work_item = self._queue.get_nowait()
+             except asyncio.QueueEmpty:
+                 return None
+
+             # Check if work is already completed
+             if await self.is_completed(work_item.hash):
+                 logger.debug(f"Work item {work_item.hash} already completed, skipping")
+                 self._queue.task_done()
+                 continue
+
+             # Check for worker lock
+             lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
+             if os.path.exists(lock_file):
+                 # Check modification time
+                 mtime = datetime.datetime.fromtimestamp(os.path.getmtime(lock_file), datetime.timezone.utc)
+                 if (datetime.datetime.now(datetime.timezone.utc) - mtime).total_seconds() > worker_lock_timeout_secs:
+                     # Lock is stale, we can take this work
+                     logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
+                 else:
+                     # Lock is active, skip this work
+                     logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
+                     self._queue.task_done()
+                     continue
+
+             # Create our lock file (touch an empty file)
+             try:
+                 with open(lock_file, "wb") as f:
+                     f.write(b"")
+             except Exception as e:
+                 logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
+                 self._queue.task_done()
+                 continue
+
+             return work_item
+
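The lock handling in `get_work` treats a lock file as stale once its modification time is older than `worker_lock_timeout_secs`, so a crashed worker's claim eventually expires. A minimal sketch of that check in isolation (a standalone helper, not part of the class above):

```python
import datetime
import os

def lock_is_stale(lock_file: str, timeout_secs: int = 1800) -> bool:
    """Return True if the lock file's mtime is older than timeout_secs."""
    mtime = datetime.datetime.fromtimestamp(os.path.getmtime(lock_file), datetime.timezone.utc)
    age_secs = (datetime.datetime.now(datetime.timezone.utc) - mtime).total_seconds()
    return age_secs > timeout_secs
```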
+     async def mark_done(self, work_item: WorkItem) -> None:
+         """
+         Mark a work item as done by removing its lock file.
+
+         Args:
+             work_item: The WorkItem to mark as done
+         """
+         lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
+         if os.path.exists(lock_file):
+             try:
+                 os.remove(lock_file)
+             except Exception as e:
+                 logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
+         self._queue.task_done()
+
+     @property
+     def size(self) -> int:
+         """Get current size of local work queue"""
+         return self._queue.qsize()
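Taken together, a typical driver loop over `LocalWorkQueue` might look like the sketch below. This is not part of the module; `process_page_group` and the workspace path are hypothetical, and note that it is the caller who writes `results/output_{hash}.jsonl`, which is what `is_completed` later checks:

```python
import asyncio
import os

async def run_local_workers(pdf_paths, workspace="./local_workspace", items_per_group=100):
    queue = LocalWorkQueue(workspace)
    await queue.populate_queue(pdf_paths, items_per_group)
    remaining = await queue.initialize_queue()
    print(f"{remaining} work items to process")

    while (work_item := await queue.get_work()) is not None:
        # Hypothetical per-path processing; returns one JSON line per path
        json_lines = [await process_page_group(p) for p in work_item.work_paths]

        # Writing results/output_{hash}.jsonl is what marks this group as completed
        out_path = os.path.join(workspace, "results", f"output_{work_item.hash}.jsonl")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write("\n".join(json_lines))

        await queue.mark_done(work_item)

# asyncio.run(run_local_workers(["/data/a.pdf", "/data/b.pdf"]))
```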
pyproject.toml ADDED
@@ -0,0 +1,75 @@
+ [build-system]
+ requires = ["setuptools", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "ocrflux"
+ description = "Fast, efficient, and high quality OCR powered by open visual language models"
+ version = "0.1.0"
+ readme = "README.md"
+ classifiers = [
+     "Intended Audience :: Science/Research",
+     "Development Status :: 3 - Alpha",
+     "License :: OSI Approved :: Apache Software License",
+     "Programming Language :: Python :: 3",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ]
+ authors = [
+     {name = "Yu Tang", email = "[email protected]"}
+ ]
+ requires-python = ">=3.11"
+ dependencies = [
+     "cached-path",
+     "smart_open",
+     "pypdf>=5.2.0",
+     "pypdfium2",
+     "cryptography",
+     "lingua-language-detector",
+     "Pillow",
+     "ftfy",
+     "bleach",
+     "markdown2",
+     "filelock",
+     "orjson",
+     "requests",
+     "zstandard",
+     "boto3",
+     "httpx",
+     "torch>=2.5.1",
+     "transformers==4.50.0",
+     "vllm==0.7.3",
+     "img2pdf",
+     "nltk",
+     "bs4",
+     "distance",
+     "apted",
+     "gradio",
+     "gradio_pdf",
+ ]
+ license = {file = "LICENSE"}
+
+ [project.urls]
+ Homepage = "https://github.com/chatdoc-com/OCRFlux"
+ Repository = "https://github.com/chatdoc-com/OCRFlux"
+
+ [tool.setuptools.packages.find]
+ exclude = [
+     "*.tests",
+     "*.tests.*",
+     "tests.*",
+     "tests",
+     "docs*",
+     "scripts*",
+     "images*"
+ ]
+
+ [tool.setuptools]
+ include-package-data = true
+
+ [tool.setuptools.package-data]
+ ocrflux = [
+     "py.typed",
+ ]
+
+ [tool.black]
+ line-length = 79
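After installing the project (for example with `pip install -e .`), the declared metadata can be sanity-checked from Python; this is just an optional verification snippet, not part of the package:

```python
from importlib import metadata

print(metadata.version("ocrflux"))               # expected: 0.1.0
print((metadata.requires("ocrflux") or [])[:5])  # first few declared dependencies
```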