ZennyKenny commited on
Commit
905cbc7
Β·
verified Β·
1 Parent(s): 3e741fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -246
app.py CHANGED
@@ -1,259 +1,229 @@
1
- import spaces
2
- import json
3
- import math
4
- import os
5
- import traceback
6
- from io import BytesIO
7
- from typing import Any, Dict, List, Optional, Tuple
8
- import re
9
-
10
- import fitz # PyMuPDF
11
  import gradio as gr
12
- import requests
13
  import torch
14
- from huggingface_hub import snapshot_download
15
- from PIL import Image, ImageDraw, ImageFont
16
- from qwen_vl_utils import process_vision_info
17
- from transformers import AutoModelForCausalLM, AutoProcessor
18
-
19
- # Constants
20
- MIN_PIXELS = 3136
21
- MAX_PIXELS = 11289600
22
- IMAGE_FACTOR = 28
23
-
24
- # Prompts
25
- prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
26
-
27
- 1. Bbox format: [x1, y1, x2, y2]
28
- 2. Layout Categories: ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
29
- 3. Text Extraction & Formatting Rules:
30
- - Picture: Omit text.
31
- - Formula: Format as LaTeX.
32
- - Table: Format as HTML.
33
- - Others: Format as Markdown.
34
- 4. Output must be the original text with no translation, sorted in human reading order.
35
- 5. Final output: single JSON object.
36
- """
37
-
38
- # Utility functions
39
- def round_by_factor(number: int, factor: int) -> int:
40
- return round(number / factor) * factor
41
-
42
- def smart_resize(height: int, width: int, factor: int = 28,
43
- min_pixels: int = 3136, max_pixels: int = 11289600):
44
- if max(height, width) / min(height, width) > 200:
45
- raise ValueError("absolute aspect ratio must be smaller than 200")
46
- h_bar = max(factor, round_by_factor(height, factor))
47
- w_bar = max(factor, round_by_factor(width, factor))
48
- if h_bar * w_bar > max_pixels:
49
- beta = math.sqrt((height * width) / max_pixels)
50
- h_bar = round_by_factor(height / beta, factor)
51
- w_bar = round_by_factor(width / beta, factor)
52
- elif h_bar * w_bar < min_pixels:
53
- beta = math.sqrt(min_pixels / (height * width))
54
- h_bar = round_by_factor(height * beta, factor)
55
- w_bar = round_by_factor(width * beta, factor)
56
- return h_bar, w_bar
57
 
58
- def fetch_image(image_input, min_pixels=None, max_pixels=None):
59
- if isinstance(image_input, str):
60
- if image_input.startswith(("http://", "https://")):
61
- response = requests.get(image_input)
62
- image = Image.open(BytesIO(response.content)).convert('RGB')
63
- else:
64
- image = Image.open(image_input).convert('RGB')
65
- elif isinstance(image_input, Image.Image):
66
- image = image_input.convert('RGB')
67
- else:
68
- raise ValueError(f"Invalid image input type: {type(image_input)}")
69
- if min_pixels is not None or max_pixels is not None:
70
- min_pixels = min_pixels or MIN_PIXELS
71
- max_pixels = max_pixels or MAX_PIXELS
72
- height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR,
73
- min_pixels=min_pixels, max_pixels=max_pixels)
74
- image = image.resize((width, height), Image.LANCZOS)
75
- return image
76
 
77
- def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
78
- images = []
79
- try:
80
- pdf_document = fitz.open(pdf_path)
81
- for page_num in range(len(pdf_document)):
82
- page = pdf_document.load_page(page_num)
83
- mat = fitz.Matrix(2.0, 2.0)
84
- pix = page.get_pixmap(matrix=mat)
85
- img_data = pix.tobytes("ppm")
86
- image = Image.open(BytesIO(img_data)).convert('RGB')
87
- images.append(image)
88
- pdf_document.close()
89
- except Exception as e:
90
- print(f"Error loading PDF: {e}")
91
- return images
92
 
93
- def is_arabic_text(text: str) -> bool:
94
- if not text:
95
- return False
96
- header_pattern = r'^#{1,6}\s+(.+)$'
97
- paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
98
- content_text = []
99
- for line in text.split('\n'):
100
- line = line.strip()
101
- if not line:
102
- continue
103
- header_match = re.match(header_pattern, line, re.MULTILINE)
104
- if header_match:
105
- content_text.append(header_match.group(1))
106
- continue
107
- if re.match(paragraph_pattern, line, re.MULTILINE):
108
- content_text.append(line)
109
- if not content_text:
110
- return False
111
- combined_text = ' '.join(content_text)
112
- arabic_chars = sum(1 for c in combined_text if '\u0600' <= c <= '\u06FF' or '\u0750' <= c <= '\u077F' or '\u08A0' <= c <= '\u08FF')
113
- total_chars = sum(1 for c in combined_text if c.isalpha())
114
- return total_chars > 0 and (arabic_chars / total_chars) > 0.5
115
 
116
- def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key='text') -> str:
117
- import base64
118
- markdown_lines = []
119
- try:
120
- sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
121
- for item in sorted_items:
122
- category = item.get('category', '')
123
- text = item.get(text_key, '')
124
- if category == 'Picture':
125
- markdown_lines.append("![Image](Image detected)\n")
126
- elif not text:
127
- continue
128
- elif category == 'Title':
129
- markdown_lines.append(f"# {text}\n")
130
- elif category == 'Section-header':
131
- markdown_lines.append(f"## {text}\n")
132
- elif category == 'Text':
133
- markdown_lines.append(f"{text}\n")
134
- elif category == 'List-item':
135
- markdown_lines.append(f"- {text}\n")
136
- elif category == 'Table':
137
- markdown_lines.append(f"{text}\n")
138
- elif category == 'Formula':
139
- markdown_lines.append(f"$$\n{text}\n$$\n")
140
- elif category == 'Caption':
141
- markdown_lines.append(f"*{text}*\n")
142
- elif category == 'Footnote':
143
- markdown_lines.append(f"^{text}^\n")
144
- except Exception as e:
145
- print(f"Error converting to markdown: {e}")
146
- return str(layout_data)
147
- return "\n".join(markdown_lines)
148
-
149
- # Model
150
- model_id = "rednote-hilab/dots.ocr"
151
- model_path = "./models/dots-ocr-local"
152
- snapshot_download(repo_id=model_id, local_dir=model_path, local_dir_use_symlinks=False)
153
- model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2",
154
- torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
155
- processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
156
- device = "cuda" if torch.cuda.is_available() else "cpu"
157
-
158
- # State
159
- pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []}
160
 
161
  @spaces.GPU()
162
- def inference(image: Image.Image, prompt: str, max_new_tokens=24000) -> str:
163
- messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
164
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
165
- image_inputs, video_inputs = process_vision_info(messages)
166
- inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(device)
167
- with torch.no_grad():
168
- generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.1)
169
- generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
170
- output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
171
- return output_text[0] if output_text else ""
172
-
173
- def process_image(image: Image.Image, min_pixels=None, max_pixels=None):
174
- if min_pixels is not None or max_pixels is not None:
175
- image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
176
- raw_output = inference(image, prompt)
177
- try:
178
- layout_data = json.loads(raw_output)
179
- return layoutjson2md(image, layout_data), layout_data
180
- except json.JSONDecodeError:
181
- return raw_output, None
 
 
 
 
 
 
 
182
 
183
- def load_file_for_preview(file_path: str):
184
- global pdf_cache
185
- if not file_path or not os.path.exists(file_path):
186
- return None, "No file selected"
187
- ext = os.path.splitext(file_path)[1].lower()
188
- if ext == '.pdf':
189
- images = load_images_from_pdf(file_path)
190
- pdf_cache.update({"images": images, "current_page": 0, "total_pages": len(images),
191
- "file_type": "pdf", "is_parsed": False, "results": []})
192
- return images[0], f"Page 1 / {len(images)}"
193
- else:
194
- img = Image.open(file_path).convert('RGB')
195
- pdf_cache.update({"images": [img], "current_page": 0, "total_pages": 1,
196
- "file_type": "image", "is_parsed": False, "results": []})
197
- return img, "Page 1 / 1"
198
-
199
- def turn_page(direction: str):
200
- global pdf_cache
201
- if not pdf_cache["images"]:
202
- return None, '<div class="page-info">No file loaded</div>', "No results yet"
203
- if direction == "prev":
204
- pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
205
- elif direction == "next":
206
- pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
207
- idx = pdf_cache["current_page"]
208
- img = pdf_cache["images"][idx]
209
- page_info_html = f'<div class="page-info">Page {idx + 1} / {pdf_cache["total_pages"]}</div>'
210
- markdown_content = "Page not processed yet"
211
- if pdf_cache["is_parsed"] and idx < len(pdf_cache["results"]):
212
- markdown_content = pdf_cache["results"][idx]
213
- if is_arabic_text(markdown_content):
214
- markdown_content = gr.update(value=markdown_content, rtl=True)
215
- return img, page_info_html, markdown_content
216
 
217
- def create_gradio_interface():
218
- css = ".page-info {text-align: center;padding: 8px 16px;border-radius: 20px;font-weight: bold;margin: 10px 0;}"
219
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
220
- gr.HTML("<h1 style='text-align:center'>πŸ” Dot-OCR - Extracted Content Only</h1>")
221
- with gr.Row():
222
- with gr.Column(scale=1):
223
- file_input = gr.File(label="Upload Image or PDF", file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"], type="filepath")
224
- image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
225
- with gr.Row():
226
- prev_page_btn = gr.Button("β—€ Previous")
227
- page_info = gr.HTML('<div class="page-info">No file loaded</div>')
228
- next_page_btn = gr.Button("Next β–Ά")
229
- process_btn = gr.Button("πŸš€ Process Document", variant="primary")
230
- clear_btn = gr.Button("πŸ—‘οΈ Clear All", variant="secondary")
231
- with gr.Column(scale=2):
232
- markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", height=500)
233
- file_input.change(load_file_for_preview, inputs=file_input, outputs=[image_preview, page_info])
234
- prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output])
235
- next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output])
236
- process_btn.click(lambda f: _process_document(f), inputs=file_input, outputs=[markdown_output])
237
- clear_btn.click(lambda: (None, None, '<div class="page-info">No file loaded</div>', "Click 'Process Document' to see extracted content..."),
238
- outputs=[file_input, image_preview, page_info, markdown_output])
239
- return demo
240
 
241
- def _process_document(file_path):
242
- global pdf_cache
243
- if not file_path:
244
- return "Please upload a file first."
245
- img, _ = load_file_for_preview(file_path)
246
- results = []
247
- for page_img in pdf_cache["images"]:
248
- md_content, _ = process_image(page_img)
249
- results.append(md_content)
250
- pdf_cache["results"] = results
251
- pdf_cache["is_parsed"] = True
252
- combined_md = "\n\n---\n\n".join(results)
253
- if is_arabic_text(combined_md):
254
- return gr.update(value=combined_md, rtl=True)
255
- return combined_md
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
  if __name__ == "__main__":
258
- demo = create_gradio_interface()
259
- demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import spaces
3
  import torch
4
+ from gradio_pdf import PDF
5
+ from pdf2image import convert_from_path
6
+ from PIL import Image
7
+ from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ model_path = "nanonets/Nanonets-OCR-s"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # Load model once at startup
12
+ print("Loading Nanonets OCR model...")
13
+ model = AutoModelForImageTextToText.from_pretrained(
14
+ model_path,
15
+ torch_dtype="auto",
16
+ device_map="auto",
17
+ attn_implementation="flash_attention_2",
18
+ )
19
+ model.eval()
 
 
 
 
 
 
20
 
21
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
22
+ processor = AutoProcessor.from_pretrained(model_path)
23
+ print("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  @spaces.GPU()
27
+ def ocr_image_gradio(image, max_tokens=4096):
28
+ """Process image through Nanonets OCR model for Gradio interface"""
29
+ if image is None:
30
+ return "Please upload an image."
31
+
32
+ prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and β˜‘ for check boxes."""
33
+
34
+ # Convert PIL image if needed
35
+ if not isinstance(image, Image.Image):
36
+ image = Image.fromarray(image)
37
+
38
+ messages = [
39
+ {"role": "system", "content": "You are a helpful assistant."},
40
+ {
41
+ "role": "user",
42
+ "content": [
43
+ {"type": "image", "image": image},
44
+ {"type": "text", "text": prompt},
45
+ ],
46
+ },
47
+ ]
48
+
49
+ text = processor.apply_chat_template(
50
+ messages, tokenize=False, add_generation_prompt=True
51
+ )
52
+ inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
53
+ inputs = inputs.to(model.device)
54
 
55
+ with torch.no_grad():
56
+ output_ids = model.generate(
57
+ **inputs,
58
+ max_new_tokens=max_tokens,
59
+ do_sample=False,
60
+ repetition_penalty=1.25,
61
+ )
62
+ generated_ids = [
63
+ output_ids[len(input_ids) :]
64
+ for input_ids, output_ids in zip(inputs.input_ids, output_ids)
65
+ ]
66
+
67
+ output_text = processor.batch_decode(
68
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
69
+ )
70
+ return output_text[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ @spaces.GPU()
74
+ def ocr_pdf_gradio(pdf_path, max_tokens=4096, progress=gr.Progress()):
75
+ """Process each page of a PDF through Nanonets OCR model"""
76
+ if pdf_path is None:
77
+ return "Please upload a PDF file."
78
+
79
+ # Convert PDF to images
80
+ progress(0, desc="Converting PDF to images...")
81
+ pdf_images = convert_from_path(pdf_path)
82
+
83
+ # Process each page
84
+ all_text = []
85
+ total_pages = len(pdf_images)
86
+
87
+ for i, image in enumerate(pdf_images):
88
+ progress(
89
+ (i + 1) / total_pages, desc=f"Processing page {i + 1}/{total_pages}..."
90
+ )
91
+ page_text = ocr_image_gradio(image, max_tokens)
92
+ all_text.append(f"--- PAGE {i + 1} ---\n{page_text}\n")
93
+
94
+ # Combine results
95
+ combined_text = "\n".join(all_text)
96
+ return combined_text
97
+
98
+
99
+ # Create Gradio interface
100
+ with gr.Blocks(title="Nanonets OCR Demo") as demo:
101
+ # Replace simple markdown with styled HTML header that includes resources
102
+ gr.HTML("""
103
+ <div class="title" style="text-align: center">
104
+ <h1>πŸ” Nanonets OCR - Document Text Extraction</h1>
105
+ <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
106
+ A state-of-the-art image-to-markdown OCR model for intelligent document processing
107
+ </p>
108
+ <div style="display: flex; justify-content: center; gap: 20px; margin: 15px 0;">
109
+ <a href="https://huggingface.co/nanonets/Nanonets-OCR-s" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
110
+ πŸ“š Hugging Face Model
111
+ </a>
112
+ <a href="https://nanonets.com/research/nanonets-ocr-s/" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
113
+ πŸ“ Release Blog
114
+ </a>
115
+ <a href="https://github.com/NanoNets/docext" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
116
+ πŸ’» GitHub Repository
117
+ </a>
118
+ </div>
119
+ </div>
120
+ """)
121
+
122
+ with gr.Tabs() as tabs:
123
+ # Image tab
124
+ with gr.TabItem("Image OCR"):
125
+ with gr.Row():
126
+ with gr.Column(scale=1):
127
+ image_input = gr.Image(
128
+ label="Upload Document Image", type="pil", height=400
129
+ )
130
+ image_max_tokens = gr.Slider(
131
+ minimum=1024,
132
+ maximum=8192,
133
+ value=4096,
134
+ step=512,
135
+ label="Max Tokens",
136
+ info="Maximum number of tokens to generate",
137
+ )
138
+ image_extract_btn = gr.Button(
139
+ "Extract Text", variant="primary", size="lg"
140
+ )
141
+
142
+ with gr.Column(scale=2):
143
+ image_output_text = gr.Textbox(
144
+ label="Extracted Text",
145
+ lines=20,
146
+ show_copy_button=True,
147
+ placeholder="Extracted text will appear here...",
148
+ )
149
+
150
+ # PDF tab
151
+ with gr.TabItem("PDF OCR"):
152
+ with gr.Row():
153
+ with gr.Column(scale=1):
154
+ pdf_input = PDF(label="Upload PDF Document", height=400)
155
+ pdf_max_tokens = gr.Slider(
156
+ minimum=1024,
157
+ maximum=8192,
158
+ value=4096,
159
+ step=512,
160
+ label="Max Tokens per Page",
161
+ info="Maximum number of tokens to generate for each page",
162
+ )
163
+ pdf_extract_btn = gr.Button(
164
+ "Extract PDF Text", variant="primary", size="lg"
165
+ )
166
+
167
+ with gr.Column(scale=2):
168
+ pdf_output_text = gr.Textbox(
169
+ label="Extracted Text (All Pages)",
170
+ lines=20,
171
+ show_copy_button=True,
172
+ placeholder="Extracted text will appear here...",
173
+ )
174
+
175
+ # Event handlers for Image tab
176
+ image_extract_btn.click(
177
+ fn=ocr_image_gradio,
178
+ inputs=[image_input, image_max_tokens],
179
+ outputs=image_output_text,
180
+ show_progress=True,
181
+ )
182
+
183
+ image_input.change(
184
+ fn=ocr_image_gradio,
185
+ inputs=[image_input, image_max_tokens],
186
+ outputs=image_output_text,
187
+ show_progress=True,
188
+ )
189
+
190
+ # Event handlers for PDF tab
191
+ pdf_extract_btn.click(
192
+ fn=ocr_pdf_gradio,
193
+ inputs=[pdf_input, pdf_max_tokens],
194
+ outputs=pdf_output_text,
195
+ show_progress=True,
196
+ )
197
+
198
+ # Add model information section
199
+ with gr.Accordion("About Nanonets-OCR-s", open=False):
200
+ gr.Markdown("""
201
+ ## Nanonets-OCR-s
202
+
203
+ Nanonets-OCR-s is a powerful, state-of-the-art image-to-markdown OCR model that goes far beyond traditional text extraction.
204
+ It transforms documents into structured markdown with intelligent content recognition and semantic tagging, making it ideal
205
+ for downstream processing by Large Language Models (LLMs).
206
+
207
+ ### Key Features
208
+
209
+ - **LaTeX Equation Recognition**: Automatically converts mathematical equations and formulas into properly formatted LaTeX syntax.
210
+ It distinguishes between inline ($...$) and display ($$...$$) equations.
211
+
212
+ - **Intelligent Image Description**: Describes images within documents using structured `<img>` tags, making them digestible
213
+ for LLM processing. It can describe various image types, including logos, charts, graphs and so on, detailing their content,
214
+ style, and context.
215
+
216
+ - **Signature Detection & Isolation**: Identifies and isolates signatures from other text, outputting them within a `<signature>` tag.
217
+ This is crucial for processing legal and business documents.
218
+
219
+ - **Watermark Extraction**: Detects and extracts watermark text from documents, placing it within a `<watermark>` tag.
220
+
221
+ - **Smart Checkbox Handling**: Converts form checkboxes and radio buttons into standardized Unicode symbols (☐, β˜‘, β˜’)
222
+ for consistent and reliable processing.
223
+
224
+ - **Complex Table Extraction**: Accurately extracts complex tables from documents and converts them into both markdown
225
+ and HTML table formats.
226
+ """)
227
 
228
  if __name__ == "__main__":
229
+ demo.queue().launch(ssr_mode=False)