Commit
·
f3b1f58
1
Parent(s):
aa3b56b
Update README.md
Browse files
README.md
CHANGED
@@ -4,10 +4,59 @@ pipeline_tag: zero-shot-classification
|
|
4 |
tags:
|
5 |
- rubert
|
6 |
- russian
|
|
|
|
|
|
|
7 |
widget:
|
8 |
- text: "Я хочу поехать в Австралию"
|
9 |
candidate_labels: "спорт,путешествия,музыка,кино,книги,наука,политика"
|
10 |
hypothesis_template: "Тема текста - {}."
|
11 |
---
|
12 |
-
# RuBERT base model (cased) for NLI
|
13 |
-
The model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
tags:
|
5 |
- rubert
|
6 |
- russian
|
7 |
+
- nli
|
8 |
+
- rte
|
9 |
+
- zero-shot-classification
|
10 |
widget:
|
11 |
- text: "Я хочу поехать в Австралию"
|
12 |
candidate_labels: "спорт,путешествия,музыка,кино,книги,наука,политика"
|
13 |
hypothesis_template: "Тема текста - {}."
|
14 |
---
|
15 |
+
# RuBERT base model (cased) fine-tuned for NLI (natural language inference)
|
16 |
+
The model has been trained on a series of NLI datasets automatically translated to Russian from English [from this repo](https://github.com/felipessalvatore/NLI_datasets).
|
17 |
+
|
18 |
+
It predicts the logical relationship between two short texts: entailment, contradiction, or neutral.
|
19 |
+
|
20 |
+
|
21 |
+
How to run the model for NLI:
|
22 |
+
```python
|
23 |
+
# !pip install transformers sentencepiece --quiet
|
24 |
+
import torch
|
25 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
26 |
+
|
27 |
+
model_checkpoint = 'cointegrated/rubert-base-cased-nli-threeway'
|
28 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
29 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
|
30 |
+
if torch.cuda.is_available():
|
31 |
+
model.cuda()
|
32 |
+
|
33 |
+
text1 = 'Сократ - человек, а все люди смертны.'
|
34 |
+
text2 = 'Сократ никогда не умрёт.'
|
35 |
+
with torch.inference_mode():
|
36 |
+
out = model(**tokenizer(text1, text2, return_tensors='pt').to(model.device))
|
37 |
+
proba = torch.softmax(out.logits, -1).cpu().numpy()[0]
|
38 |
+
print({v: proba[k] for k, v in model.config.id2label.items()})
|
39 |
+
# {'entailment': 0.009525929, 'contradiction': 0.9332064, 'neutral': 0.05726764}
|
40 |
+
```
|
41 |
+
|
42 |
+
You can also use this model for zero-shot short text classification (by labels only), e.g. for sentiment analysis:
|
43 |
+
|
44 |
+
```python
|
45 |
+
def predict_zero_shot(text, label_texts, model, tokenizer, label='entailment', normalize=True):
|
46 |
+
label_texts
|
47 |
+
tokens = tokenizer([text] * len(label_texts), label_texts, truncation=True, return_tensors='pt', padding=True)
|
48 |
+
with torch.inference_mode():
|
49 |
+
result = torch.softmax(model(**tokens.to(model.device)).logits, -1)
|
50 |
+
proba = result[:, model.config.label2id[label]].cpu().numpy()
|
51 |
+
if normalize:
|
52 |
+
proba /= sum(proba)
|
53 |
+
return proba
|
54 |
+
|
55 |
+
classes = ['Я доволен', 'Я недоволен']
|
56 |
+
predict_zero_shot('Какая гадость эта ваша заливная рыба!', classes, model, tokenizer)
|
57 |
+
# array([0.05609814, 0.9439019 ], dtype=float32)
|
58 |
+
predict_zero_shot('Какая вкусная эта ваша заливная рыба!', classes, model, tokenizer)
|
59 |
+
# array([0.9059292 , 0.09407079], dtype=float32)
|
60 |
+
```
|
61 |
+
|
62 |
+
Alternatively, you can use [Huggingface pipelines](https://huggingface.co/transformers/main_classes/pipelines.html) for inference.
|