bokesyo committed on
Commit
d90ed2d
·
1 Parent(s): b3fffcd
Files changed (1)
  1. app.py +120 -133
app.py CHANGED
@@ -16,7 +16,17 @@ import fitz
  import threading
  import gradio as gr
  import spaces


  def get_image_md5(img: Image.Image):
      img_byte_array = img.tobytes()
@@ -25,152 +35,129 @@ def get_image_md5(img: Image.Image):
      hex_digest = hash_md5.hexdigest()
      return hex_digest

- def pdf_to_images(pdf_path, dpi=100):
-     doc = fitz.open(pdf_path)
-     images = []
-     for page in tqdm.tqdm(doc):
-         pix = page.get_pixmap(dpi=dpi)
-         img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
-         images.append(img)
-     return images
-
  def calculate_md5_from_binary(binary_data):
      hash_md5 = hashlib.md5()
      hash_md5.update(binary_data)
      return hash_md5.hexdigest()

- class PDFVisualRetrieval:
-     def __init__(self, model, tokenizer):
-         self.tokenizer = tokenizer
-         self.model = model
-         self.reps = {}
-         self.images = {}
-         self.lock = threading.Lock()

-     def retrieve(self, knowledge_base: str, query: str, topk: int):
-         doc_reps = list(self.reps[knowledge_base].values())
-         query_with_instruction = "Represent this query for retrieving relavant document: " + query
-         with torch.no_grad():
-             query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
-         doc_reps_cat = torch.stack(doc_reps, dim=0)
-         similarities = torch.matmul(query_rep, doc_reps_cat.T)
-         topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
-         topk_values_np = topk_values.cpu().numpy()
-         topk_doc_ids_np = topk_doc_ids.cpu().numpy()
-         similarities_np = similarities.cpu().numpy()
-         all_images_doc_list = list(self.images[knowledge_base].values())
-         images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
-         return topk_doc_ids_np, topk_values_np, images_topk

-     def add_pdf(self, knowledge_base_name: str, pdf_file_path: str, dpi: int = 100):
-         if knowledge_base_name not in self.reps:
-             self.reps[knowledge_base_name] = {}
-         if knowledge_base_name not in self.images:
-             self.images[knowledge_base_name] = {}
-         doc = fitz.open(pdf_file_path)
-         print("model encoding images..")
-         for page in tqdm.tqdm(doc):
-             pix = page.get_pixmap(dpi=dpi)
-             image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
-             image_md5 = get_image_md5(image)
-             with torch.no_grad():
-                 reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
-             self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
-             self.images[knowledge_base_name][image_md5] = image
-         return
-
-     def add_pdf_gradio(self, pdf_file_binary, progress=gr.Progress()):
-         knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
-         if knowledge_base_name not in self.reps:
-             self.reps[knowledge_base_name] = {}
-         else:
-             return knowledge_base_name
-         if knowledge_base_name not in self.images:
-             self.images[knowledge_base_name] = {}
-         dpi = 100
-         doc = fitz.open("pdf", pdf_file_binary)
-
-         for page in progress.tqdm(doc):
-             # with self.lock: # because we hope one 16G gpu only process one image at the same time
-             pix = page.get_pixmap(dpi=dpi)
-             image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
-             image_md5 = get_image_md5(image)
-             with torch.no_grad():
-                 reps = self.model(text=[''], image=[image], tokenizer=self.tokenizer).reps
-             self.reps[knowledge_base_name][image_md5] = reps.squeeze(0)
-             self.images[knowledge_base_name][image_md5] = image
-
-         return knowledge_base_name
-
-     def retrieve_gradio(self, knowledge_base: str, query: str, topk: int):
-         doc_reps = list(self.reps[knowledge_base].values())
-         query_with_instruction = "Represent this query for retrieving relavant document: " + query
          with torch.no_grad():
-             query_rep = self.model(text=[query_with_instruction], image=[None], tokenizer=self.tokenizer).reps.squeeze(0)
-         doc_reps_cat = torch.stack(doc_reps, dim=0)
-         similarities = torch.matmul(query_rep, doc_reps_cat.T)
-         topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
-         topk_values_np = topk_values.cpu().numpy()
-         topk_doc_ids_np = topk_doc_ids.cpu().numpy()
-         similarities_np = similarities.cpu().numpy()
-         all_images_doc_list = list(self.images[knowledge_base].values())
-         images_topk = [all_images_doc_list[idx] for idx in topk_doc_ids_np]
-         return images_topk
-
-
- if __name__ == "__main__":
-     from transformers import AutoModel
-     from transformers import AutoTokenizer
-     from PIL import Image
-     import torch

-     device = 'cuda:0'

-     # Load model, be sure to substitute `model_path` by your model path
-     model_path = 'RhapsodyAI/minicpm-visual-embedding-v0' # replace with your local model path
-     # pdf_path = "/home/jeeves/xubokai/minicpm-visual-embedding-v0/2406.07422v1.pdf"

-     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-     model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
-     model.to(device)

-     retriever = PDFVisualRetrieval(model=model, tokenizer=tokenizer)

-     @spaces.GPU
-     def add_pdf_gradio(pdf_file_binary):
-         return retriever.add_pdf_gradio(pdf_file_binary)

-     @spaces.GPU
-     def retrieve_gradio(knowledge_base, query, topk):
-         return retriever.retrieve_gradio(knowledge_base, query, topk)
-
-     # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='what is the number of VQ of this kind of codec method?', topk=1)
-     # # 2
-     # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the training loss curve of this paper?', topk=1)
-     # # 3
-     # topk_doc_ids_np, topk_values_np, images_topk = retriever.retrieve(knowledge_base='test', query='the experiment table?', topk=1)
-     # # 2

-     with gr.Blocks() as app:
-         gr.Markdown("# Memex: OCR-free Visual Document Retrieval @RhapsodyAI")
-
-         with gr.Row():
-             file_input = gr.File(type="binary", label="Upload PDF")
-             file_result = gr.Text(label="Knowledge Base ID (remember this!)")
-             process_button = gr.Button("Process PDF")
-
-         process_button.click(add_pdf_gradio, inputs=[file_input], outputs=file_result)
-
-         with gr.Row():
-             kb_id_input = gr.Text(label="Your Knowledge Base ID")
-             query_input = gr.Text(label="Your Queston")
-             topk_input = inputs=gr.Number(value=1, minimum=1, maximum=5, step=1, label="Top K")
-             retrieve_button = gr.Button("Retrieve")
-
-         with gr.Row():
-             images_output = gr.Gallery(label="Retrieved Pages")
-
-         retrieve_button.click(retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
-
-     app.launch()

  import threading
  import gradio as gr
  import spaces
+ import os
+ from transformers import AutoModel
+ from transformers import AutoTokenizer
+ from PIL import Image
+ import torch
+ import os
+ import numpy as np
+ import json

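+ # On-disk cache: each uploaded PDF gets its own subdirectory under this
+ # path, keyed by the MD5 hash of the PDF bytes.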
+ cache_dir = '/data/kb_cache'
+ os.makedirs(cache_dir, exist_ok=True)

  def get_image_md5(img: Image.Image):
      img_byte_array = img.tobytes()

      hex_digest = hash_md5.hexdigest()
      return hex_digest

  def calculate_md5_from_binary(binary_data):
      hash_md5 = hashlib.md5()
      hash_md5.update(binary_data)
      return hash_md5.hexdigest()

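+ # Build a knowledge base from an uploaded PDF: render the pages, embed
+ # them, and persist everything to the cache so retrieval can reload it.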
+ @spaces.GPU(duration=120)
+ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
+     global model, tokenizer
+
+     knowledge_base_name = calculate_md5_from_binary(pdf_file_binary)
+
+     this_cache_dir = os.path.join(cache_dir, knowledge_base_name)
+     os.makedirs(this_cache_dir, exist_ok=True)
+
+     with open(os.path.join(this_cache_dir, f"src.pdf"), 'wb') as file:
+         file.write(pdf_file_binary)
+
+     dpi = 100
+     doc = fitz.open("pdf", pdf_file_binary)
+
+     reps_list = []
+     images = []
+     image_md5s = []
+
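+     # Render each page to an RGB image and encode it into one embedding
+     # vector per page.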
+     for page in progress.tqdm(doc):
+         # with self.lock: # because we hope one 16G gpu only process one image at the same time
+         pix = page.get_pixmap(dpi=dpi)
+         image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+         image_md5 = get_image_md5(image)
+         image_md5s.append(image_md5)
          with torch.no_grad():
+             reps = model(text=[''], image=[image], tokenizer=tokenizer).reps
+         reps_list.append(reps.squeeze(0).cpu().numpy())
+         images.append(image)
+
+     for idx in range(len(images)):
+         image = images[idx]
+         image_md5 = image_md5s[idx]
+         cache_image_path = os.path.join(this_cache_dir, f"{image_md5}.png")
+         image.save(cache_image_path)
+
+     np.save(os.path.join(this_cache_dir, f"reps.npy"), reps_list)
+
+     with open(os.path.join(this_cache_dir, f"md5s.txt"), 'w') as f:
+         for item in image_md5s:
+             f.write(item+'\n')
+
+     return knowledge_base_name
+
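+ # Answer a query against a cached knowledge base: embed the query, score
+ # it against the stored page embeddings, and return the top-k page images.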
+ # @spaces.GPU
+ def retrieve_gradio(knowledge_base: str, query: str, topk: int):
+     global model, tokenizer
+
+     target_cache_dir = os.path.join(cache_dir, knowledge_base)
+
+     if not os.path.exists(target_cache_dir):
+         return None

+     md5s = []
+     with open(os.path.join(target_cache_dir, f"md5s.txt"), 'r') as f:
+         for line in f:
+             md5s.append(line.rstrip('\n'))
+
+     doc_reps = np.load(os.path.join(target_cache_dir, f"reps.npy"))
+
+     query_with_instruction = "Represent this query for retrieving relavant document: " + query
+     with torch.no_grad():
+         query_rep = model(text=[query_with_instruction], image=[None], tokenizer=tokenizer).reps.squeeze(0).cpu()
+
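+     # Queries are logged to the knowledge-base cache (see the consent
+     # notice rendered in the UI below).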
+     query_md5 = hashlib.md5(query.encode()).hexdigest()
+     with open(os.path.join(target_cache_dir, f"q-{query_md5}.json"), 'w') as f:
+         f.write(json.dumps(
+             {
+                 "query": query
+             }, indent=4, ensure_ascii=False
+         ))
+
+     doc_reps_cat = torch.stack([torch.Tensor(i) for i in doc_reps], dim=0)
+
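+     # Rank pages by dot-product similarity between the query embedding and
+     # each page embedding.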
+     similarities = torch.matmul(query_rep, doc_reps_cat.T)
+
+     topk_values, topk_doc_ids = torch.topk(similarities, k=topk)
+
+     topk_values_np = topk_values.cpu().numpy()
+     topk_doc_ids_np = topk_doc_ids.cpu().numpy()
+     similarities_np = similarities.cpu().numpy()
+
+     images_topk = [Image.open(os.path.join(target_cache_dir, f"{md5s[idx]}.png")) for idx in topk_doc_ids_np]
+
+     return images_topk
+
+
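+ # Load the tokenizer and embedding model once at startup and move them to
+ # the GPU.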
+ device = 'cuda'
+ model_path = 'RhapsodyAI/minicpm-visual-embedding-v0' # replace with your local model path
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+ model.to(device)
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("# Memex: OCR-free Visual Document Retrieval @RhapsodyAI")

+     with gr.Row():
+         file_input = gr.File(type="binary", label="Upload PDF")
+         file_result = gr.Text(label="Knowledge Base ID (remember this!)")
+         process_button = gr.Button("Process PDF")

+     process_button.click(add_pdf_gradio, inputs=[file_input], outputs=file_result)
+
+     with gr.Row():
+         kb_id_input = gr.Text(label="Your Knowledge Base ID")
+         query_input = gr.Text(label="Your Question")
+         topk_input = gr.Number(value=1, minimum=1, maximum=5, step=1, label="Top K")
+         retrieve_button = gr.Button("Retrieve")

+     with gr.Row():
+         images_output = gr.Gallery(label="Retrieved Pages")

+     retrieve_button.click(retrieve_gradio, inputs=[kb_id_input, query_input, topk_input], outputs=images_output)
+
+     gr.Markdown("By using this demo, you agree to share your usage data with us for research purposes, to help improve the user experience.")
+
+ app.launch()