ZennyKenny committed
Commit 0f28e05 · verified · 1 Parent(s): bb8ebf0

Update app.py

Files changed (1)
  1. app.py +244 -94
app.py CHANGED
@@ -1,109 +1,259 @@
-# app.py
-import spaces # must be first
-
 import traceback
 from io import BytesIO
-from typing import Tuple
 
 import gradio as gr
 import requests
 import torch
 from huggingface_hub import snapshot_download
-from PIL import Image
 from qwen_vl_utils import process_vision_info
-from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
-
-# --- Config ---
-OCR_REPO = "rednote-hilab/dots.ocr"
-OCR_LOCAL = "./models/dots-ocr-local"
-CONVERT_REPO = "ZennyKenny/oss-20b-prereform-to-modern-ru-merged"
-
-SYSTEM_MSG = (
-    "You convert Russian text from pre-1918 orthography to modern Russian spelling. "
-    "Keep wording and punctuation; change only orthography."
-)
-OCR_PROMPT = (
-    "Extract the original text from this image as plain text. "
-    "Keep the reading order. Do not translate. Do not add extra formatting."
-)
-
-# --- Snapshot OCR locally (same technique as the working Space) ---
-snapshot_download(repo_id=OCR_REPO, local_dir=OCR_LOCAL, local_dir_use_symlinks=False)
-
-# --- Load models at module scope (after spaces import) ---
-# Expecting flash-attn to be available & ABI-compatible now
-ocr_model = AutoModelForCausalLM.from_pretrained(
-    OCR_LOCAL,
-    attn_implementation="flash_attention_2",
-    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
-    device_map="auto",
-    trust_remote_code=True,
-)
-ocr_processor = AutoProcessor.from_pretrained(OCR_LOCAL, trust_remote_code=True)
-
-tok = AutoTokenizer.from_pretrained(CONVERT_REPO, use_fast=True)
-conv_model = AutoModelForCausalLM.from_pretrained(
-    CONVERT_REPO,
-    device_map="auto",
-    torch_dtype="auto",
-)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-def fetch_image(x) -> Image.Image:
-    if isinstance(x, Image.Image):
-        return x.convert("RGB")
-    if isinstance(x, str):
-        if x.startswith(("http://", "https://")):
-            r = requests.get(x, timeout=30); r.raise_for_status()
-            return Image.open(BytesIO(r.content)).convert("RGB")
-        return Image.open(x).convert("RGB")
-    raise ValueError(f"Unsupported input: {type(x)}")
-
-def run_ocr(img: Image.Image) -> str:
-    messages = [{"role":"user","content":[{"type":"image","image":img},{"type":"text","text":OCR_PROMPT}]}]
-    text = ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
-    inputs = ocr_processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(device)
-    with torch.no_grad():
-        out = ocr_model.generate(**inputs, max_new_tokens=4096, do_sample=False, temperature=0.0)
-    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out)]
-    s = ocr_processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    return s.strip()
-
-def convert_pre_to_modern(txt: str) -> str:
-    messages = [{"role":"system","content":SYSTEM_MSG},{"role":"user","content":txt}]
-    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = tok([prompt], return_tensors="pt").to(conv_model.device)
     with torch.no_grad():
-        gen = conv_model.generate(**inputs, max_new_tokens=1024, do_sample=False, temperature=0.0, repetition_penalty=1.05)
-    gen_only = gen[0][inputs["input_ids"].shape[1]:]
-    return tok.decode(gen_only, skip_special_tokens=True).strip()
 
-@spaces.GPU()
-def transcribe_and_convert(image_in) -> Tuple[Image.Image, str, str, str]:
     try:
-        img = fetch_image(image_in)
-        ocr_text = run_ocr(img)
-        modern = convert_pre_to_modern(ocr_text)
-        md = f"```text\n{modern}\n```"
-        return img, ocr_text, modern, md
-    except Exception as e:
-        traceback.print_exc()
-        return None, "", "", f"Error: {e}"
-
-with gr.Blocks(title="Pre-reform → Modern Russian (OCR + Conversion)") as demo:
-    gr.Markdown("Upload an image with pre-1918 Russian → OCR (dots.ocr) → convert to modern Russian.")
-    with gr.Row():
-        with gr.Column(scale=1):
-            img_in = gr.Image(type="pil", label="Upload image")
-            btn = gr.Button("Transcribe & Convert", variant="primary")
-        with gr.Column(scale=2):
-            with gr.Row():
-                img_out = gr.Image(label="Preview", interactive=False)
-                ocr_box = gr.Textbox(label="Transcribed (pre-reform)", lines=14)
-            modern_box = gr.Textbox(label="Modern Russian", lines=14)
-            md_box = gr.Markdown(label="Markdown block")
-    btn.click(transcribe_and_convert, [img_in], [img_out, ocr_box, modern_box, md_box], api_name="transcribe_convert")
-
-demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)
+import spaces
+import json
+import math
+import os
 import traceback
 from io import BytesIO
+from typing import Any, Dict, List, Optional, Tuple
+import re
 
+import fitz # PyMuPDF
 import gradio as gr
 import requests
 import torch
 from huggingface_hub import snapshot_download
+from PIL import Image, ImageDraw, ImageFont
 from qwen_vl_utils import process_vision_info
+from transformers import AutoModelForCausalLM, AutoProcessor
+
+# Constants
+MIN_PIXELS = 3136
+MAX_PIXELS = 11289600
+IMAGE_FACTOR = 28
+
+# Prompts
+prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
+
+1. Bbox format: [x1, y1, x2, y2]
+2. Layout Categories: ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
+3. Text Extraction & Formatting Rules:
+    - Picture: Omit text.
+    - Formula: Format as LaTeX.
+    - Table: Format as HTML.
+    - Others: Format as Markdown.
+4. Output must be the original text with no translation, sorted in human reading order.
+5. Final output: single JSON object.
+"""
+
+# Utility functions
+def round_by_factor(number: int, factor: int) -> int:
+    return round(number / factor) * factor
+
+def smart_resize(height: int, width: int, factor: int = 28,
+                 min_pixels: int = 3136, max_pixels: int = 11289600):
+    if max(height, width) / min(height, width) > 200:
+        raise ValueError("absolute aspect ratio must be smaller than 200")
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = round_by_factor(height / beta, factor)
+        w_bar = round_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = round_by_factor(height * beta, factor)
+        w_bar = round_by_factor(width * beta, factor)
+    return h_bar, w_bar
+
+def fetch_image(image_input, min_pixels=None, max_pixels=None):
+    if isinstance(image_input, str):
+        if image_input.startswith(("http://", "https://")):
+            response = requests.get(image_input)
+            image = Image.open(BytesIO(response.content)).convert('RGB')
+        else:
+            image = Image.open(image_input).convert('RGB')
+    elif isinstance(image_input, Image.Image):
+        image = image_input.convert('RGB')
+    else:
+        raise ValueError(f"Invalid image input type: {type(image_input)}")
+    if min_pixels is not None or max_pixels is not None:
+        min_pixels = min_pixels or MIN_PIXELS
+        max_pixels = max_pixels or MAX_PIXELS
+        height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR,
+                                     min_pixels=min_pixels, max_pixels=max_pixels)
+        image = image.resize((width, height), Image.LANCZOS)
+    return image
+
+def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
+    images = []
+    try:
+        pdf_document = fitz.open(pdf_path)
+        for page_num in range(len(pdf_document)):
+            page = pdf_document.load_page(page_num)
+            mat = fitz.Matrix(2.0, 2.0)
+            pix = page.get_pixmap(matrix=mat)
+            img_data = pix.tobytes("ppm")
+            image = Image.open(BytesIO(img_data)).convert('RGB')
+            images.append(image)
+        pdf_document.close()
+    except Exception as e:
+        print(f"Error loading PDF: {e}")
+    return images
+
+def is_arabic_text(text: str) -> bool:
+    if not text:
+        return False
+    header_pattern = r'^#{1,6}\s+(.+)$'
+    paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
+    content_text = []
+    for line in text.split('\n'):
+        line = line.strip()
+        if not line:
+            continue
+        header_match = re.match(header_pattern, line, re.MULTILINE)
+        if header_match:
+            content_text.append(header_match.group(1))
+            continue
+        if re.match(paragraph_pattern, line, re.MULTILINE):
+            content_text.append(line)
+    if not content_text:
+        return False
+    combined_text = ' '.join(content_text)
+    arabic_chars = sum(1 for c in combined_text if '\u0600' <= c <= '\u06FF' or '\u0750' <= c <= '\u077F' or '\u08A0' <= c <= '\u08FF')
+    total_chars = sum(1 for c in combined_text if c.isalpha())
+    return total_chars > 0 and (arabic_chars / total_chars) > 0.5
+
+def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key='text') -> str:
+    import base64
+    markdown_lines = []
+    try:
+        sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
+        for item in sorted_items:
+            category = item.get('category', '')
+            text = item.get(text_key, '')
+            if category == 'Picture':
+                markdown_lines.append("![Image](Image detected)\n")
+            elif not text:
+                continue
+            elif category == 'Title':
+                markdown_lines.append(f"# {text}\n")
+            elif category == 'Section-header':
+                markdown_lines.append(f"## {text}\n")
+            elif category == 'Text':
+                markdown_lines.append(f"{text}\n")
+            elif category == 'List-item':
+                markdown_lines.append(f"- {text}\n")
+            elif category == 'Table':
+                markdown_lines.append(f"{text}\n")
+            elif category == 'Formula':
+                markdown_lines.append(f"$$\n{text}\n$$\n")
+            elif category == 'Caption':
+                markdown_lines.append(f"*{text}*\n")
+            elif category == 'Footnote':
+                markdown_lines.append(f"^{text}^\n")
+    except Exception as e:
+        print(f"Error converting to markdown: {e}")
+        return str(layout_data)
+    return "\n".join(markdown_lines)
 
+# Model
+model_id = "rednote-hilab/dots.ocr"
+model_path = "./models/dots-ocr-local"
+snapshot_download(repo_id=model_id, local_dir=model_path, local_dir_use_symlinks=False)
+model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2",
+                                             torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+# State
+pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []}
+
+@spaces.GPU()
+def inference(image: Image.Image, prompt: str, max_new_tokens=24000) -> str:
+    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(device)
     with torch.no_grad():
+        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.1)
+    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+    output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+    return output_text[0] if output_text else ""
 
+def process_image(image: Image.Image, min_pixels=None, max_pixels=None):
+    if min_pixels is not None or max_pixels is not None:
+        image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
+    raw_output = inference(image, prompt)
     try:
+        layout_data = json.loads(raw_output)
+        return layoutjson2md(image, layout_data), layout_data
+    except json.JSONDecodeError:
+        return raw_output, None
+
+def load_file_for_preview(file_path: str):
+    global pdf_cache
+    if not file_path or not os.path.exists(file_path):
+        return None, "No file selected"
+    ext = os.path.splitext(file_path)[1].lower()
+    if ext == '.pdf':
+        images = load_images_from_pdf(file_path)
+        pdf_cache.update({"images": images, "current_page": 0, "total_pages": len(images),
+                          "file_type": "pdf", "is_parsed": False, "results": []})
+        return images[0], f"Page 1 / {len(images)}"
+    else:
+        img = Image.open(file_path).convert('RGB')
+        pdf_cache.update({"images": [img], "current_page": 0, "total_pages": 1,
+                          "file_type": "image", "is_parsed": False, "results": []})
+        return img, "Page 1 / 1"
+
+def turn_page(direction: str):
+    global pdf_cache
+    if not pdf_cache["images"]:
+        return None, '<div class="page-info">No file loaded</div>', "No results yet"
+    if direction == "prev":
+        pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
+    elif direction == "next":
+        pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
+    idx = pdf_cache["current_page"]
+    img = pdf_cache["images"][idx]
+    page_info_html = f'<div class="page-info">Page {idx + 1} / {pdf_cache["total_pages"]}</div>'
+    markdown_content = "Page not processed yet"
+    if pdf_cache["is_parsed"] and idx < len(pdf_cache["results"]):
+        markdown_content = pdf_cache["results"][idx]
+        if is_arabic_text(markdown_content):
+            markdown_content = gr.update(value=markdown_content, rtl=True)
+    return img, page_info_html, markdown_content
+
+def create_gradio_interface():
+    css = ".page-info {text-align: center;padding: 8px 16px;border-radius: 20px;font-weight: bold;margin: 10px 0;}"
+    with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+        gr.HTML("<h1 style='text-align:center'>🔍 Dot-OCR - Extracted Content Only</h1>")
+        with gr.Row():
+            with gr.Column(scale=1):
+                file_input = gr.File(label="Upload Image or PDF", file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"], type="filepath")
+                image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
+                with gr.Row():
+                    prev_page_btn = gr.Button("◀ Previous")
+                    page_info = gr.HTML('<div class="page-info">No file loaded</div>')
+                    next_page_btn = gr.Button("Next ▶")
+                process_btn = gr.Button("🚀 Process Document", variant="primary")
+                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
+            with gr.Column(scale=2):
+                markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", height=500)
+        file_input.change(load_file_for_preview, inputs=file_input, outputs=[image_preview, page_info])
+        prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output])
+        next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output])
+        process_btn.click(lambda f: _process_document(f), inputs=file_input, outputs=[markdown_output])
+        clear_btn.click(lambda: (None, None, '<div class="page-info">No file loaded</div>', "Click 'Process Document' to see extracted content..."),
+                        outputs=[file_input, image_preview, page_info, markdown_output])
+    return demo
+
+def _process_document(file_path):
+    global pdf_cache
+    if not file_path:
+        return "Please upload a file first."
+    img, _ = load_file_for_preview(file_path)
+    results = []
+    for page_img in pdf_cache["images"]:
+        md_content, _ = process_image(page_img)
+        results.append(md_content)
+    pdf_cache["results"] = results
+    pdf_cache["is_parsed"] = True
+    combined_md = "\n\n---\n\n".join(results)
+    if is_arabic_text(combined_md):
+        return gr.update(value=combined_md, rtl=True)
+    return combined_md
+
+if __name__ == "__main__":
+    demo = create_gradio_interface()
+    demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860)