pierreguillou committed on
Commit
c9b4e80
·
1 Parent(s): 70bfccd

Update files/functions.py

Browse files
Files changed (1) hide show
  1. files/functions.py +36 -0
files/functions.py CHANGED
@@ -147,6 +147,42 @@ for lang_t, langcode_t in zip(langs_t,langscode_t):
147
  langdetect2Tesseract = {v:k for k,v in Tesseract2langdetect.items()}
148
 
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  # General
151
 
152
  # get text and bounding boxes from an image
 
147
  langdetect2Tesseract = {v:k for k,v in Tesseract2langdetect.items()}
148
 
149
 
150
+ ## Module-level setup: load the models, feature extractor and tokenizers (runs when this file is executed/imported)
151
+
152
+ # pick the compute device: CUDA GPU when torch reports one available, otherwise CPU
153
+ import torch
154
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
155
+
156
+ ## LiLT: tokenizer and token-classification model, both loaded from model_id_lilt
157
+ import transformers  # NOTE(review): not referenced by name in this hunk; confirm it is used elsewhere
158
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
159
+ tokenizer_lilt = AutoTokenizer.from_pretrained(model_id_lilt)
160
+ model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt);
161
+ model_lilt.to(device);  # move the LiLT model to the selected device
162
+
163
+ ## LayoutXLM: token-classification model loaded from model_id_layoutxlm
164
+ from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast is not imported; AutoTokenizer is used for the tokenizer below
165
+ model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm);
166
+ model_layoutxlm.to(device);  # move the LayoutXLM model to the selected device
167
+
168
+ # feature extractor, created with its built-in OCR switched off (apply_ocr=False)
169
+ from transformers import LayoutLMv2FeatureExtractor
170
+ feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
171
+
172
+ # LayoutXLM tokenizer, loaded from tokenizer_id_layoutxlm
173
+ from transformers import AutoTokenizer  # already imported above for LiLT; repeated here
174
+ tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)
175
+
176
+ # label mappings (id -> label, label -> id) and label counts, read from each model's config
177
+ id2label_lilt = model_lilt.config.id2label
178
+ label2id_lilt = model_lilt.config.label2id
179
+ num_labels_lilt = len(id2label_lilt)
180
+
181
+ id2label_layoutxlm = model_layoutxlm.config.id2label
182
+ label2id_layoutxlm = model_layoutxlm.config.label2id
183
+ num_labels_layoutxlm = len(id2label_layoutxlm)
184
+
185
+
186
  # General
187
 
188
  # get text and bounding boxes from an image