Mageia committed on
Commit
a749292
·
unverified ·
1 Parent(s): d0f2987

fix: process pdf once

Browse files
Files changed (2) hide show
  1. app-ocr.py +63 -127
  2. app.py +80 -36
app-ocr.py CHANGED
@@ -1,146 +1,82 @@
1
  import base64
2
- import io
3
  import os
4
- import shutil
5
- import time
6
- import uuid
7
- from pathlib import Path
8
 
9
  import gradio as gr
10
- from modelscope import AutoModel, AutoTokenizer
 
 
11
 
12
- UPLOAD_FOLDER = "./uploads"
13
- RESULTS_FOLDER = "./results"
14
 
 
 
 
15
 
16
- tokenizer = AutoTokenizer.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True)
17
- model = AutoModel.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True, low_cpu_mem_usage=True, device_map="cuda", use_safetensors=True)
18
- model = model.eval().cuda()
19
 
20
- for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
21
- if not os.path.exists(folder):
22
- os.makedirs(folder)
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- def image_to_base64(image):
26
- buffered = io.BytesIO()
27
- image.save(buffered, format="PNG")
28
- return base64.b64encode(buffered.getvalue()).decode()
29
 
 
 
30
 
31
- def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
32
- unique_id = str(uuid.uuid4())
33
- image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
34
- result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
35
 
36
- shutil.copy(image, image_path)
 
 
 
 
37
 
38
- try:
39
- if got_mode == "plain texts OCR":
40
- res = model.chat(tokenizer, image_path, ocr_type="ocr")
41
- return res, None
42
- elif got_mode == "format texts OCR":
43
- res = model.chat(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
44
- elif got_mode == "plain multi-crop OCR":
45
- res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
46
- return res, None
47
- elif got_mode == "format multi-crop OCR":
48
- res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
49
- elif got_mode == "plain fine-grained OCR":
50
- res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
51
- return res, None
52
- elif got_mode == "format fine-grained OCR":
53
- res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
54
-
55
- # res_markdown = f"$$ {res} $$"
56
- res_markdown = res
57
-
58
- if "format" in got_mode and os.path.exists(result_path):
59
- with open(result_path, "r") as f:
60
- html_content = f.read()
61
- encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
62
- iframe_src = f"data:text/html;base64,{encoded_html}"
63
- iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
64
- download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result_{unique_id}.html">Download Full Result</a>'
65
- return res_markdown, f"{download_link}<br>{iframe}"
66
- else:
67
- return res_markdown, None
68
- except Exception as e:
69
- return f"Error: {str(e)}", None
70
- finally:
71
- if os.path.exists(image_path):
72
- os.remove(image_path)
73
-
74
-
75
- def task_update(task):
76
- if "fine-grained" in task:
77
- return [
78
- gr.update(visible=True),
79
- gr.update(visible=False),
80
- gr.update(visible=False),
81
- ]
82
- else:
83
- return [
84
- gr.update(visible=False),
85
- gr.update(visible=False),
86
- gr.update(visible=False),
87
- ]
88
-
89
-
90
- def fine_grained_update(task):
91
- if task == "box":
92
- return [
93
- gr.update(visible=False, value=""),
94
- gr.update(visible=True),
95
- ]
96
- elif task == "color":
97
- return [
98
- gr.update(visible=True),
99
- gr.update(visible=False, value=""),
100
- ]
101
-
102
-
103
- def cleanup_old_files():
104
- current_time = time.time()
105
- for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
106
- for file_path in Path(folder).glob("*"):
107
- if current_time - file_path.stat().st_mtime > 3600: # 1 hour
108
- file_path.unlink()
109
 
 
110
 
111
- with gr.Blocks() as demo:
112
- with gr.Row():
113
- with gr.Column():
114
- image_input = gr.Image(type="filepath", label="上传图片")
115
- task_dropdown = gr.Dropdown(
116
- choices=[
117
- "plain texts OCR",
118
- "format texts OCR",
119
- "plain multi-crop OCR",
120
- "format multi-crop OCR",
121
- "plain fine-grained OCR",
122
- "format fine-grained OCR",
123
- ],
124
- label="选择GOT模式",
125
- value="plain texts OCR",
126
- )
127
- fine_grained_dropdown = gr.Dropdown(choices=["box", "color"], label="fine-grained type", visible=False)
128
- color_dropdown = gr.Dropdown(choices=["red", "green", "blue"], label="color list", visible=False)
129
- box_input = gr.Textbox(label="input box: [x1,y1,x2,y2]", placeholder="e.g., [0,0,100,100]", visible=False)
130
- submit_button = gr.Button("Submit")
131
-
132
- with gr.Column():
133
- ocr_result = gr.Textbox(label="GOT output")
134
-
135
- with gr.Column():
136
- gr.Markdown("**如果选择带格式的模式,mathpix结果将自动呈现如下:**")
137
- html_result = gr.HTML(label="rendered html", show_label=True)
138
-
139
- task_dropdown.change(task_update, inputs=[task_dropdown], outputs=[fine_grained_dropdown, color_dropdown, box_input])
140
- fine_grained_dropdown.change(fine_grained_update, inputs=[fine_grained_dropdown], outputs=[color_dropdown, box_input])
141
-
142
- submit_button.click(run_GOT, inputs=[image_input, task_dropdown, fine_grained_dropdown, color_dropdown, box_input], outputs=[ocr_result, html_result])
143
 
144
  if __name__ == "__main__":
145
- cleanup_old_files()
146
  demo.launch()
 
1
  import base64
 
2
  import os
 
 
 
 
3
 
4
  import gradio as gr
5
+ import spaces
6
+ import torch
7
+ from transformers import AutoModel, AutoTokenizer
8
 
9
+ model_name = "ucaslcl/GOT-OCR2_0"
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
13
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
14
+ model = model.eval().to(device)
15
 
 
 
 
16
 
17
@spaces.GPU()
def ocr_process(image, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
    """Run the OCR model on one uploaded image and return displayable output.

    Args:
        image: Filesystem path of the uploaded image, or ``None`` when nothing
            was uploaded.
        got_mode: Mode label from the dropdown. "plain" / "format" selects the
            OCR type and "multi-crop" selects ``model.chat_crop`` over
            ``model.chat``.
        ocr_color: Forwarded as ``ocr_color`` to ``model.chat``.
        ocr_box: Forwarded as ``ocr_box`` to ``model.chat``.
        progress: Gradio progress tracker used to report coarse stages.

    Returns:
        The model output for "plain" modes; a download link plus an iframe
        preview for "format" modes when the render file exists, otherwise the
        model output; or an error string. Exceptions are reported as an error
        string instead of being raised.
    """
    if image is None:
        return "错误:未提供图片"

    try:
        image_path = image
        # The rendered HTML is requested next to the input image.
        result_path = f"{os.path.splitext(image_path)[0]}_result.html"

        progress(0, desc="开始处理...")

        use_crop = "multi-crop" in got_mode

        if "plain" in got_mode:
            progress(0.3, desc="执行OCR识别...")
            if use_crop:
                res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            else:
                res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
            progress(1, desc="处理完成")
            return res

        if "format" in got_mode:
            progress(0.3, desc="执行OCR识别...")
            if use_crop:
                res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
            else:
                res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)

            progress(0.7, desc="生成结果...")
            if os.path.exists(result_path):
                with open(result_path, "r", encoding="utf-8") as f:
                    html_content = f.read()
                # Embed the page as a data URI so it needs no served file.
                encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
                data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
                preview = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
                download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
                progress(1, desc="处理完成")
                return f"{download_link}\n\n{preview}"

            # No render file was produced: return the recognised text rather
            # than falling through to the misleading "unknown mode" error.
            progress(1, desc="处理完成")
            return res

        return "错误: 未知的OCR模式"
    except Exception as e:
        return f"错误: {str(e)}"
57
 
 
 
 
 
58
 
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown("# OCR 图像识别")
61
 
62
+ with gr.Row():
63
+ image_input = gr.Image(type="filepath", label="上传图片")
 
 
64
 
65
+ got_mode = gr.Dropdown(
66
+ choices=["plain texts OCR", "format texts OCR", "plain multi-crop OCR", "format multi-crop OCR", "plain fine-grained OCR", "format fine-grained OCR"],
67
+ label="OCR模式",
68
+ value="plain texts OCR",
69
+ )
70
 
71
+ with gr.Row():
72
+ ocr_color = gr.Textbox(label="OCR颜色 (仅用于fine-grained模式)")
73
+ ocr_box = gr.Textbox(label="OCR边界框 (仅用于fine-grained模式)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ submit_button = gr.Button("开始OCR识别")
76
 
77
+ output = gr.HTML(label="识别结果")
78
+
79
+ submit_button.click(ocr_process, inputs=[image_input, got_mode, ocr_color, ocr_box], outputs=output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  if __name__ == "__main__":
 
82
  demo.launch()
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import base64
2
  import os
3
- import time
4
 
 
5
  import gradio as gr
6
  import spaces
7
  import torch
 
8
  from transformers import AutoModel, AutoTokenizer
9
 
10
  model_name = "ucaslcl/GOT-OCR2_0"
@@ -15,53 +17,95 @@ model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map
15
  model = model.eval().to(device)
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @spaces.GPU()
19
- def ocr_process(image, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
20
- if image is None:
21
- return "错误:未提供图片"
22
 
23
  try:
24
- image_path = image
25
- result_path = f"{os.path.splitext(image_path)[0]}_result.html"
26
-
27
  progress(0, desc="开始处理...")
28
 
29
- if "plain" in got_mode:
30
- progress(0.3, desc="执行OCR识别...")
31
- if "multi-crop" in got_mode:
32
- res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
33
- else:
34
- res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
35
- progress(1, desc="处理完成")
36
- return res
37
- elif "format" in got_mode:
38
- progress(0.3, desc="执行OCR识别...")
39
- if "multi-crop" in got_mode:
40
- res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
 
 
 
 
 
 
 
41
  else:
42
- res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
43
-
44
- progress(0.7, desc="生成结果...")
45
- if os.path.exists(result_path):
46
- with open(result_path, "r", encoding="utf-8") as f:
47
- html_content = f.read()
48
- encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
49
- data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
50
- preview = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
51
- download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
52
- progress(1, desc="处理完成")
53
- return f"{download_link}\n\n{preview}"
54
-
55
- return "错误: 未知的OCR模式"
56
  except Exception as e:
57
  return f"错误: {str(e)}"
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  with gr.Blocks() as demo:
61
  gr.Markdown("# OCR 图像识别")
62
 
63
- with gr.Row():
64
- image_input = gr.Image(type="filepath", label="上传图片")
65
 
66
  got_mode = gr.Dropdown(
67
  choices=["plain texts OCR", "format texts OCR", "plain multi-crop OCR", "format multi-crop OCR", "plain fine-grained OCR", "format fine-grained OCR"],
@@ -77,7 +121,7 @@ with gr.Blocks() as demo:
77
 
78
  output = gr.HTML(label="识别结果")
79
 
80
- submit_button.click(ocr_process, inputs=[image_input, got_mode, ocr_color, ocr_box], outputs=output)
81
 
82
  if __name__ == "__main__":
83
  demo.launch()
 
1
  import base64
2
  import os
3
+ import tempfile
4
 
5
+ import fitz
6
  import gradio as gr
7
  import spaces
8
  import torch
9
+ from PIL import Image, ImageEnhance
10
  from transformers import AutoModel, AutoTokenizer
11
 
12
  model_name = "ucaslcl/GOT-OCR2_0"
 
17
  model = model.eval().to(device)
18
 
19
 
20
def pdf_to_images(pdf_path, zoom=10, contrast=1.5):
    """Render every page of a PDF to a contrast-enhanced RGB PIL image.

    Args:
        pdf_path: Path of the PDF file, opened with ``fitz.open``.
        zoom: Scale factor used for both axes of the render matrix. Defaults
            to 10, the value that used to be hard-coded. Width and height both
            scale with it, so pixel count grows with ``zoom ** 2``.
        contrast: Factor passed to ``ImageEnhance.Contrast(...).enhance``.
            Defaults to 1.5, the value that used to be hard-coded.

    Returns:
        A list of PIL images, one per page, in page order.
    """
    # The matrix does not depend on the page, so build it once.
    mat = fitz.Matrix(zoom, zoom)
    images = []
    pdf_document = fitz.open(pdf_path)
    try:
        for page_num in range(len(pdf_document)):
            page = pdf_document.load_page(page_num)
            pix = page.get_pixmap(matrix=mat, alpha=False)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

            # Boost contrast before the image is handed to OCR.
            img = ImageEnhance.Contrast(img).enhance(contrast)

            images.append(img)
    finally:
        # Close the document even when rendering a page raises.
        pdf_document.close()
    return images
37
+
38
+
39
@spaces.GPU()
def ocr_process(file, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
    """Run OCR on an uploaded image or PDF and return the combined result.

    Args:
        file: The upload from the file input, or ``None``. Accepted either as
            a plain path string or as an object exposing the path via
            ``.name``.
        got_mode: Mode label, forwarded to ``process_single_image``.
        ocr_color: Forwarded to ``process_single_image``.
        ocr_box: Forwarded to ``process_single_image``.
        progress: Gradio progress tracker; updated once per PDF page.

    Returns:
        For a PDF, the per-page results joined by blank lines, each prefixed
        with its page number. For any other file, the single result from
        ``process_single_image``. Exceptions are reported as an error string
        instead of being raised.
    """
    if file is None:
        return "错误:未提供文件"

    try:
        progress(0, desc="开始处理...")

        # Resolve the upload to a path instead of calling ``file.read()``.
        # NOTE(review): Gradio may pass a path-like string that only carries
        # ``.name`` rather than an open file -- confirm against the installed
        # Gradio version.
        src_path = str(getattr(file, "name", file))

        with tempfile.TemporaryDirectory() as temp_dir:
            if src_path.lower().endswith(".pdf"):
                # Render straight from the uploaded file; no copy is needed.
                images = pdf_to_images(src_path)
                num_pages = len(images)
                results = []

                for i, image in enumerate(images):
                    progress((i + 1) / num_pages, desc=f"处理第 {i+1}/{num_pages} 页...")
                    img_path = os.path.join(temp_dir, f"page_{i+1}.png")
                    image.save(img_path, "PNG")

                    result = process_single_image(img_path, got_mode, ocr_color, ocr_box)
                    results.append(f"第 {i+1} 页结果:\n{result}")

                final_result = "\n\n".join(results)
            else:
                # Copy the image into the temp dir so any file derived from
                # its path is removed together with the directory.
                img_path = os.path.join(temp_dir, "temp_image.png")
                with open(src_path, "rb") as src, open(img_path, "wb") as dst:
                    dst.write(src.read())
                final_result = process_single_image(img_path, got_mode, ocr_color, ocr_box)

        progress(1, desc="处理完成")
        return final_result
    except Exception as e:
        return f"错误: {str(e)}"
76
 
77
 
78
def process_single_image(image_path, got_mode, ocr_color, ocr_box):
    """Run the OCR model on one image file and build the displayable result.

    Args:
        image_path: Path of the image passed to the model.
        got_mode: Mode label. "plain" / "format" selects the OCR type and
            "multi-crop" selects ``model.chat_crop`` over ``model.chat``.
        ocr_color: Forwarded as ``ocr_color`` to ``model.chat``.
        ocr_box: Forwarded as ``ocr_box`` to ``model.chat``.

    Returns:
        The model output for "plain" modes. For "format" modes, a download
        link, an iframe preview and the model output when the render file
        exists, otherwise the model output alone. An error string when the
        mode is neither "plain" nor "format".
    """
    is_plain = "plain" in got_mode
    if not is_plain and "format" not in got_mode:
        return "错误: 未知的OCR模式"

    use_crop = "multi-crop" in got_mode

    if is_plain:
        if use_crop:
            return model.chat_crop(tokenizer, image_path, ocr_type="ocr")
        return model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)

    # The rendered HTML is requested next to the input image.
    result_path = f"{os.path.splitext(image_path)[0]}_result.html"
    if use_crop:
        res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
    else:
        res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)

    if not os.path.exists(result_path):
        # No render file was produced: return the recognised text rather than
        # reporting a misleading "unknown mode" error.
        return res

    with open(result_path, "r", encoding="utf-8") as f:
        html_content = f.read()
    # Embed the page as a data URI so it needs no served file.
    encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    preview = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
    download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
    return f"{download_link}\n\n{preview}\n\n识别结果:\n{res}"
+
104
+
105
  with gr.Blocks() as demo:
106
  gr.Markdown("# OCR 图像识别")
107
 
108
+ file_input = gr.File(label="上传PDF或图片文件")
 
109
 
110
  got_mode = gr.Dropdown(
111
  choices=["plain texts OCR", "format texts OCR", "plain multi-crop OCR", "format multi-crop OCR", "plain fine-grained OCR", "format fine-grained OCR"],
 
121
 
122
  output = gr.HTML(label="识别结果")
123
 
124
+ submit_button.click(ocr_process, inputs=[file_input, got_mode, ocr_color, ocr_box], outputs=output)
125
 
126
  if __name__ == "__main__":
127
  demo.launch()