 
# Model Card for Model ID

Extract POS receipt image data to a JSON record.

## Model Details

A fine-tuned version of Google's PaliGemma model for extracting receipt images into JSON records.

### Model Usage

Set up the environment:
```
pip install transformers==4.42.2
pip install datasets
pip install peft accelerate bitsandbytes
```
Specify the device:
```
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device_map places the whole model on GPU 0
device_map = {"": 0}
```
Step 1: Load the image processor
```
from transformers import AutoProcessor

FINETUNED_MODEL_ID = "mychen76/paligemma-receipt-json-3b-mix-448-v2b"
processor = AutoProcessor.from_pretrained(FINETUNED_MODEL_ID)
```
Step 2: Set the task prompt and prepare the inputs
```
from PIL import Image

TASK_PROMPT = "EXTRACT_JSON_RECEIPT"
MAX_LENGTH = 512

# test_image: the receipt image to extract ("receipt.png" is an example path)
test_image = Image.open("receipt.png").convert("RGB")

inputs = processor(text=TASK_PROMPT, images=test_image, return_tensors="pt").to(device)
for k, v in inputs.items():
    print(k, v.shape)
```
Step 3: Load the model
```
import torch
from transformers import PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig

# Load the full model
model = PaliGemmaForConditionalGeneration.from_pretrained(FINETUNED_MODEL_ID, device_map={"": 0})
```
Or load a quantized version:
```
# Q-LoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
# LoRA config (only needed if fine-tuning further with get_peft_model)
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM"
)
model = PaliGemmaForConditionalGeneration.from_pretrained(FINETUNED_MODEL_ID, quantization_config=bnb_config, device_map={"": 0})
```
Step 4: Run inference
```
# Generate autoregressively with greedy decoding; for fancier methods
# see https://huggingface.co/blog/how-to-generate
generated_ids = model.generate(**inputs, max_new_tokens=MAX_LENGTH)

# Turn the predicted token IDs back into a string with the decode method,
# chopping off the prompt, which consists of the image tokens and the text prompt
image_token_index = model.config.image_token_index
num_image_tokens = len(generated_ids[generated_ids == image_token_index])
num_text_tokens = len(processor.tokenizer.encode(TASK_PROMPT))
num_prompt_tokens = num_image_tokens + num_text_tokens + 2
generated_text = processor.batch_decode(generated_ids[:, num_prompt_tokens:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(generated_text)
```
Result tokens:
```
'<s_total></s_total><s_tips></s_tips><s_time></s_time><s_telephone>(718)308-1118</s_telephone><s_tax></s_tax><s_subtotal></s_subtotal><s_store_name></s_store_name><s_store_addr>Brooklyn,NY11211</s_store_addr><s_line_items><s_item_value>2.98</s_item_value><s_item_quantity>1</s_item_quantity><s_item_name>NORI</s_item_name><s_item_key></s_item_key><sep/><s_item_value>2.35</s_item_value><s_item_quantity>1</s_item_quantity><s_item_name>TOMATOESPLUM</s_item_name><s_item_key></s_item_key><sep/><s_item_value>0.97</s_item_value><s_item_quantity>1</s_item_quantity><s_item_name>ONIONSVIDALIA</s_item_name><s_item_key></s_item_key><sep/><s_item_value>2.48</s_item_value><s_item_quantity>1</s_item_quantity><s_item_name>HAMBURRN</s_item_name><s_item_key></s_item_key><sep/><s_item_value>0.99</s_item_value><s_item_quantity>1</s_item_quantity><s_item_name>FTRAWBERRY</s_item_name><s_item_key></s_item_key><sep/><s_item_value>0.99</s_item_value><s_item_quantity>1</s_item_quantity><s_item_name>FTRAWBERRY</s_item_name><s_item_key></s_item_key><sep/><s_item_value>0.57</s_item_value><s_item_quantity>1</s_item_quantity><s_item_name>PILSNER</'
```
Step 5: Convert the result to JSON (borrowed from the Donut model)
```
import re

def token2json(tokens, is_inner_value=False, added_vocab=None):
    """
    Convert a (generated) token sequence into an ordered JSON format.
    """
    if added_vocab is None:
        added_vocab = processor.tokenizer.get_added_vocab()

    output = {}

    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        key_escaped = re.escape(key)

        end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(
                f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL
            )
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = token2json(content, is_inner_value=True, added_vocab=added_vocab)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
                            leaf = leaf[1:-2]  # for categorical special tokens
                        output[key].append(leaf)
                    if len(output[key]) == 1:
                        output[key] = output[key][0]

            tokens = tokens[tokens.find(end_token) + len(end_token):].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)

    if len(output):
        return [output] if is_inner_value else output
    else:
        return [] if is_inner_value else {"text_sequence": tokens}


# generate the JSON record
generated_json = token2json(generated_text)
print(generated_json)
```
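The parser above relies on two regular expressions per key: one that finds the next `<s_…>` opening tag and derives the key name from it, and one that captures the content up to the matching `</s_…>` closing tag. A minimal self-contained sketch of that matching step (the token string here is a made-up example):

```python
import re

tokens = "<s_total>9.99</s_total>"

# Find the opening tag and derive the key name from it.
start = re.search(r"<s_(.*?)>", tokens)
key = start.group(1)  # "total"

# Capture the content up to the matching closing tag.
content = re.search(rf"<s_{re.escape(key)}>(.*?)</s_{re.escape(key)}>", tokens, re.DOTALL)
print({key: content.group(1).strip()})  # {'total': '9.99'}
```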
Final result in JSON:
```
[{'total': '',
  'tips': '',
  'time': '',
  'telephone': '(718)308-1118',
  'tax': '',
  'subtotal': '',
  'store_name': '',
  'store_addr': 'Brooklyn,NY11211',
  'item_value': '2.98',
  'item_quantity': '1',
  'item_name': 'NORI',
  'item_key': ''},
 {'item_value': '2.35',
  'item_quantity': '1',
  'item_name': 'TOMATOESPLUM',
  'item_key': ''},
 {'item_value': '0.97',
  'item_quantity': '1',
  'item_name': 'ONIONSVIDALIA',
  'item_key': ''},
 {'item_value': '2.48',
  'item_quantity': '1',
  'item_name': 'HAMBURRN',
  'item_key': ''},
 {'item_value': '0.99',
  'item_quantity': '1',
  'item_name': 'FTRAWBERRY',
  'item_key': ''},
 {'item_value': '0.99',
  'item_quantity': '1',
  'item_name': 'FTRAWBERRY',
  'item_key': ''},
 {'item_value': '0.57', 'item_quantity': '1'}]
```
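Since `token2json` returns plain Python dicts and lists, the result can be serialized directly with the standard `json` module; a minimal sketch using a record abbreviated from the example output above:

```python
import json

# Abbreviated receipt record, taken from the example output above.
record = [{"telephone": "(718)308-1118",
           "store_addr": "Brooklyn,NY11211",
           "item_name": "NORI",
           "item_value": "2.98"}]

# Serialize to a JSON string for storage or downstream processing.
json_text = json.dumps(record, indent=2)
print(json_text)
```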
<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.

- **Developed by:** [email protected]
- **Model type:** Vision model for receipt image data extraction
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** PaliGemma-3b-pt-224

### Model Sources [optional]

### Training Data

See the dataset: mychen76/invoices-and-receipts_ocr_v1

[More Information Needed]