[email protected]
commited on
Commit
•
fa7ceb4
1
Parent(s):
629edc1
comment batch generating
Browse files- ocr.py +7 -7
- output.png +2 -2
ocr.py
CHANGED
@@ -5,7 +5,7 @@ from PIL import Image, ImageDraw
|
|
5 |
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
|
6 |
|
7 |
repo = "microsoft/kosmos-2.5"
|
8 |
-
device = "cuda:
|
9 |
dtype = torch.bfloat16
|
10 |
model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, torch_dtype=dtype)
|
11 |
processor = AutoProcessor.from_pretrained(repo)
|
@@ -22,12 +22,12 @@ raw_width, raw_height = image.size
|
|
22 |
scale_height = raw_height / height
|
23 |
scale_width = raw_width / width
|
24 |
|
25 |
-
# bs > 1, batch
|
26 |
-
inputs = processor(text=[prompt, prompt], images=[image,image], return_tensors="pt")
|
27 |
-
height, width = inputs.pop("height"), inputs.pop("width")
|
28 |
-
raw_width, raw_height = image.size
|
29 |
-
scale_height = raw_height / height[0]
|
30 |
-
scale_width = raw_width / width[0]
|
31 |
|
32 |
inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
|
33 |
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
|
|
|
5 |
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration
|
6 |
|
7 |
repo = "microsoft/kosmos-2.5"
|
8 |
+
device = "cuda:0"
|
9 |
dtype = torch.bfloat16
|
10 |
model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, device_map=device, torch_dtype=dtype)
|
11 |
processor = AutoProcessor.from_pretrained(repo)
|
|
|
22 |
scale_height = raw_height / height
|
23 |
scale_width = raw_width / width
|
24 |
|
25 |
+
# bs > 1, batch generation
|
26 |
+
# inputs = processor(text=[prompt, prompt], images=[image,image], return_tensors="pt")
|
27 |
+
# height, width = inputs.pop("height"), inputs.pop("width")
|
28 |
+
# raw_width, raw_height = image.size
|
29 |
+
# scale_height = raw_height / height[0]
|
30 |
+
# scale_width = raw_width / width[0]
|
31 |
|
32 |
inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
|
33 |
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
|
output.png
CHANGED
Git LFS Details
|
Git LFS Details
|