Cricles commited on
Commit
5d567f2
·
verified ·
1 Parent(s): 2774079

Upload 5 files

Browse files
model_wrapper/__init__.py ADDED
File without changes
model_wrapper/bert_wrapper.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSequenceClassification
3
+ from transformers import BertTokenizerFast
4
+
5
+
6
+ class BertWrapper(object):
7
+ MODELS_DIR: str = "new_models/"
8
+ MODEL_NAME: str = "model"
9
+ TOKENIZER: str = "tokenizer"
10
+
11
+ def __init__(self) -> None:
12
+ self.model = AutoModelForSequenceClassification.from_pretrained(
13
+ BertWrapper.MODELS_DIR + BertWrapper.MODEL_NAME, torchscript=True
14
+ )
15
+ self.tokenizer = BertTokenizerFast.from_pretrained(
16
+ BertWrapper.MODELS_DIR + BertWrapper.TOKENIZER
17
+ )
18
+ self.id2label: dict[int, str] = {0: "__label__positive", 1: "__label__negative"}
19
+
20
+ @torch.no_grad()
21
+ def __call__(self, text: str) -> str:
22
+ max_input_length = (
23
+ self.model.config.max_position_embeddings
24
+ ) # 512 for this model
25
+ inputs = self.tokenizer(
26
+ text,
27
+ max_length=max_input_length,
28
+ padding=True,
29
+ truncation=True,
30
+ return_tensors="pt",
31
+ )
32
+ outputs = self.model(
33
+ **inputs, return_dict=True
34
+ ) # output is logits for huggingfcae transformers
35
+ predicted = torch.nn.functional.softmax(outputs.logits, dim=1)
36
+ predicted_id = torch.argmax(predicted, dim=1).numpy()[0]
37
+ return self.id2label[predicted_id]
model_wrapper/fasttext_wrapper.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import fasttext
2
+
3
+ class FasttextWrapper(object):
4
+ MODEL_PATH: str = "./model_fasttext.bin"
5
+
6
+ def __init__(self) -> None:
7
+ self.model = fasttext.load_model(FasttextWrapper.MODEL_PATH)
8
+
9
+ def __call__(self, text: str) -> str:
10
+ return self.model.predict(text)
model_wrapper/frida_wrapper.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Захреначьте импорты сюда и сделайте метод call, как в других обёртках
2
+
3
+ class FridaWrapper(object):
4
+ def __init__(self) -> None:
5
+ pass
6
+
7
+ def __call__(self, text: str) -> str:
8
+ pass
model_wrapper/model_wrapper.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bert_wrapper import BertWrapper
2
+ from fasttext_wrapper import FasttextWrapper
3
+ from frida_wrapper import FridaWrapper
4
+
5
+ from typing import Any
6
+
7
+ class ModelWrapper(object):
8
+ def __init__(self) -> None:
9
+ self.models_dict: dict[str, Any] = {
10
+ "fasttext": FasttextWrapper(),
11
+ "ru-BERT": BertWrapper(),
12
+ "FRIDA": FridaWrapper(),
13
+ }
14
+
15
+ def __call__(self, text: str, model_name: str) -> str:
16
+ return self.models_dict[model_name](text)