kirp committed on
Commit
f313d57
·
verified ·
1 Parent(s): 5465152

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +13 -8
README.md CHANGED
@@ -10,15 +10,19 @@ pip install git+https://github.com/tic-top/transformers.git
10
  from transformers import AutoModelForVision2Seq, AutoProcessor
11
  from PIL import Image
12
  import torch
13
- device = "cuda:0"
14
- repo = "kirp/kosmos2_5"
15
- dtype = torch.bfloat16
16
- model = AutoModelForVision2Seq.from_pretrained(repo, device_map = device).to(dtype)
 
 
 
17
  processor = AutoProcessor.from_pretrained(repo)
18
 
19
  path = "receipt_00008.png"
20
  image = Image.open(path)
21
- prompt = "<ocr>" # "<md>"
 
22
  inputs = processor(text=prompt, images=image, return_tensors="pt", max_patches=4096)
23
 
24
  raw_width, raw_height = image.size
@@ -29,8 +33,8 @@ scale_width = raw_width / width
29
  inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
30
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
31
  with torch.no_grad():
32
- generated_text = model.generate(**inputs, max_new_tokens=256)
33
-
34
  import re, os
35
  def postprocess(y, scale_height, scale_width, result_path=None):
36
  y = (
@@ -40,6 +44,7 @@ def postprocess(y, scale_height, scale_width, result_path=None):
40
  .replace("</image>", "")
41
  .replace(prompt, "")
42
  )
 
43
  pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
44
  bboxs_raw = re.findall(pattern, y)
45
  lines = re.split(pattern, y)[1:]
@@ -67,5 +72,5 @@ def postprocess(y, scale_height, scale_width, result_path=None):
67
  else:
68
  print(info)
69
 
70
- postprocess(processor.batch_decode(generated_text)[0],scale_height, scale_width)
71
  ```
 
10
  from transformers import AutoModelForVision2Seq, AutoProcessor
11
  from PIL import Image
12
  import torch
13
+ device = "cuda:2"
14
+ repo = "kosmos2_5"
15
+ dtype = torch.float16
16
+ # dtype = torch.bfloat16
17
+ model = AutoModelForVision2Seq.from_pretrained(repo, device_map = device, torch_dtype=dtype)
18
+ # print(model)
19
+ # exit(0)
20
  processor = AutoProcessor.from_pretrained(repo)
21
 
22
  path = "receipt_00008.png"
23
  image = Image.open(path)
24
+ prompt = "<ocr>"
25
+ # prompt = "<md>"
26
  inputs = processor(text=prompt, images=image, return_tensors="pt", max_patches=4096)
27
 
28
  raw_width, raw_height = image.size
 
33
  inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
34
  inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
35
  with torch.no_grad():
36
+ generated_text = model.generate(**inputs, max_length=4096)
37
+ generated_text = processor.batch_decode(generated_text)
38
  import re, os
39
  def postprocess(y, scale_height, scale_width, result_path=None):
40
  y = (
 
44
  .replace("</image>", "")
45
  .replace(prompt, "")
46
  )
47
+ print(y)
48
  pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
49
  bboxs_raw = re.findall(pattern, y)
50
  lines = re.split(pattern, y)[1:]
 
72
  else:
73
  print(info)
74
 
75
+ postprocess(generated_text[0], scale_height, scale_width)
76
  ```