Update README.md
Browse files
README.md
CHANGED
@@ -10,15 +10,19 @@ pip install git+https://github.com/tic-top/transformers.git
10   from transformers import AutoModelForVision2Seq, AutoProcessor
11   from PIL import Image
12   import torch
13 - device = "cuda:
14 - repo = "
15 - dtype = torch.
16 -
17   processor = AutoProcessor.from_pretrained(repo)
18
19   path = "receipt_00008.png"
20   image = Image.open(path)
21 - prompt = "<ocr>"
22   inputs = processor(text=prompt, images=image, return_tensors="pt", max_patches=4096)
23
24   raw_width, raw_height = image.size
@@ -29,8 +33,8 @@ scale_width = raw_width / width
29   inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
30   inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
31   with torch.no_grad():
32 -     generated_text = model.generate(**inputs,
33 -
34   import re, os
35   def postprocess(y, scale_height, scale_width, result_path=None):
36       y = (
@@ -40,6 +44,7 @@ def postprocess(y, scale_height, scale_width, result_path=None):
40           .replace("</image>", "")
41           .replace(prompt, "")
42       )
43       pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
44       bboxs_raw = re.findall(pattern, y)
45       lines = re.split(pattern, y)[1:]
@@ -67,5 +72,5 @@ def postprocess(y, scale_height, scale_width, result_path=None):
67   else:
68       print(info)
69
70 - postprocess(
71   ```
10   from transformers import AutoModelForVision2Seq, AutoProcessor
11   from PIL import Image
12   import torch
13 + device = "cuda:2"
14 + repo = "kosmos2_5"
15 + dtype = torch.float16
16 + # dtype = torch.bfloat16
17 + model = AutoModelForVision2Seq.from_pretrained(repo, device_map = device, torch_dtype=dtype)
18 + # print(model)
19 + # exit(0)
20   processor = AutoProcessor.from_pretrained(repo)
21
22   path = "receipt_00008.png"
23   image = Image.open(path)
24 + prompt = "<ocr>"
25 + # prompt = "<md>"
26   inputs = processor(text=prompt, images=image, return_tensors="pt", max_patches=4096)
27
28   raw_width, raw_height = image.size
33   inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
34   inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
35   with torch.no_grad():
36 +     generated_text = model.generate(**inputs, max_length=4096)
37 +     generated_text = processor.batch_decode(generated_text)
38   import re, os
39   def postprocess(y, scale_height, scale_width, result_path=None):
40       y = (
44           .replace("</image>", "")
45           .replace(prompt, "")
46       )
47 +     print(y)
48       pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
49       bboxs_raw = re.findall(pattern, y)
50       lines = re.split(pattern, y)[1:]
72   else:
73       print(info)
74
75 + postprocess(generated_text[0], scale_height, scale_width)
76   ```