sayed99 committed
Commit 1787ca5 · verified · 1 Parent(s): 11556b3

docs: README updated for optimized usage with the transformers library

The Python code for transformers usage has been updated to use flash-attn as the attention implementation, boosting performance and reducing memory usage.
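
The core of the change is the model-loading call. A minimal sketch, assuming flash-attn is installed and a CUDA GPU is available:

```python
# Sketch of the updated loading step: request FlashAttention-2 as the attention backend.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model = AutoModelForCausalLM.from_pretrained(
    "PaddlePaddle/PaddleOCR-VL",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # new: FlashAttention-2 instead of the default backend
).to("cuda").eval()
processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL", trust_remote_code=True)
```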

Files changed (1)
  1. README.md +64 -23
README.md CHANGED
@@ -149,49 +149,90 @@ Currently, we support inference using the PaddleOCR-VL-0.9B model with the `transformers` library
 > [!NOTE]
 > Note: We currently recommend using the official method for inference, as it is faster and supports page-level document parsing. The example code below only supports element-level recognition.
 
+```bash
+# 1. Install the dependencies, including flash-attn (run in a Colab/Jupyter cell)
+!uv pip install -q "transformers>=4.55" bitsandbytes accelerate
+!uv pip install flash-attn --no-build-isolation
+```
+
 ```python
-from PIL import Image
+# 1.2 Import the necessary libraries
 import torch
 from transformers import AutoModelForCausalLM, AutoProcessor
+from PIL import Image
+from google.colab import files
+
 
+# 2. Upload an image (drag & drop any PNG/JPG; uses the Colab files widget)
+uploaded = files.upload()
+image_path = list(uploaded.keys())[-1]
+print(f"Using: {image_path}")
+
+
+# 3. Resize so the longest side is at most 2048 px, preserving the aspect ratio
+img = Image.open(image_path).convert("RGB")
+max_size = 2048
+w, h = img.size
+if max(w, h) > max_size:
+    scale = max_size / max(w, h)
+    new_w, new_h = int(w * scale), int(h * scale)
+    img = img.resize((new_w, new_h), Image.LANCZOS)
+    print(f"Resized → {img.size[0]}×{img.size[1]}")
+print(f"Current size → {img.size[0]}×{img.size[1]}")
+
+
+# 4. Load the model with FlashAttention-2
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-CHOSEN_TASK = "ocr"  # Options: 'ocr' | 'table' | 'chart' | 'formula'
+model = AutoModelForCausalLM.from_pretrained(
+    "PaddlePaddle/PaddleOCR-VL",
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+).to(DEVICE).eval()
+
+processor = AutoProcessor.from_pretrained("PaddlePaddle/PaddleOCR-VL", trust_remote_code=True)
+
+
+
+
+# 5. Choose the task
+TASK = "ocr"  # ← change to "table" | "chart" | "formula"
 PROMPTS = {
     "ocr": "OCR:",
     "table": "Table Recognition:",
-    "formula": "Formula Recognition:",
     "chart": "Chart Recognition:",
+    "formula": "Formula Recognition:",
 }
 
-model_path = "PaddlePaddle/PaddleOCR-VL"
-image_path = "test.png"
-image = Image.open(image_path).convert("RGB")
+# 6. Build the chat messages
+messages = [{"role": "user", "content": [{"type": "image", "image": img},
+                                         {"type": "text", "text": PROMPTS[TASK]}]}]
 
-model = AutoModelForCausalLM.from_pretrained(
-    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
-).to(DEVICE).eval()
-processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
-
-messages = [
-    {"role": "user",
-     "content": [
-        {"type": "image", "image": image},
-        {"type": "text", "text": PROMPTS[CHOSEN_TASK]},
-    ]
-    }
-]
 inputs = processor.apply_chat_template(
     messages,
     tokenize=True,
     add_generation_prompt=True,
     return_dict=True,
-    return_tensors="pt"
+    return_tensors="pt"
 ).to(DEVICE)
 
-outputs = model.generate(**inputs, max_new_tokens=1024)
-outputs = processor.batch_decode(outputs, skip_special_tokens=True)[0]
-print(outputs)
+
+
+
+# 7. Run inference
+with torch.inference_mode():
+    out = model.generate(
+        **inputs,
+        max_new_tokens=1024,
+        do_sample=False,
+        use_cache=True
+    )
+
+# 8. Decode the output
+result = processor.batch_decode(out, skip_special_tokens=True)[0]
+print("\n" + "="*60 + "\nRESULT:\n" + "="*60)
+print(result)
 ```
 
 ## Performance