File size: 11,883 Bytes
8de55e4
19b279b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8de55e4
19b279b
 
8de55e4
19b279b
 
 
 
 
 
 
 
 
32a8238
19b279b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32a8238
 
19b279b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32a8238
19b279b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
---

language:
- ru
- en

pipeline_tag: sentence-similarity

tags:
- russian
- pretraining
- embeddings
- tiny
- feature-extraction
- sentence-similarity
- sentence-transformers
- transformers

datasets:
- IlyaGusev/gazeta
- zloelias/lenta-ru
- HuggingFaceFW/fineweb-2
- HuggingFaceFW/fineweb

license: mit
base_model: sergeyzh/rubert-mini-sts

---


## rubert-mini-frida - лёгкая и быстрая модификация FRIDA

Модель для расчетов эмбеддингов предложений на русском и английском языках получена методом дистилляции эмбеддингов [ai-forever/FRIDA](https://huggingface.co/ai-forever/FRIDA) (размер эмбеддингов - 1536, слоёв - 24) в [sergeyzh/rubert-mini-sts](https://huggingface.co/sergeyzh/rubert-mini-sts) (размер эмбеддингов - 312, слоёв - 7). Основной режим использования FRIDA - CLS pooling заменен на mean pooling. Каких-либо других  изменений поведения модели (модификации или фильтрации эмбеддингов, использования дополнительной модели) не производилось. Дистиляция выполнена в максимально возможном объеме - эмбеддинги русских и английских предложений, работа префиксов. 

Рекомендуемый размер контекста модели соответствует FRIDA и не превышает 512 токенов (фактический унаследованный от исходной модели - 2048).

## Префиксы
Все префиксы унаследованы от FRIDA. 
Оптимальный (обеспечивающий средние результаты) для большинства задач - "categorize: " прописан по умолчанию в [config_sentence_transformers.json](https://huggingface.co/sergeyzh/rubert-mini-frida/blob/main/config_sentence_transformers.json)

Перечень используемых префиксов и их влияние на оценки модели в [encodechka](https://github.com/avidale/encodechka):

| Префикс                | STS       | PI        | NLI       | SA        | TI        |
|:-----------------------|:---------:|:---------:|:---------:|:---------:|:---------:|
| -                      |   0.839   |   0.762   |   0.475   |   0.801   |   0.972   |
| search_query:          |   0.846   |   0.761   |   0.498   |   0.800   |   0.973   |

| search_document:       |   0.830   |   0.748   |   0.468   |   0.794   |   0.972   |
| paraphrase:            |   0.835   | **0.764** |   0.475   |   0.799   |   0.973   |
| categorize:            | **0.850** |   0.761   |   0.516   |   0.802   | **0.973** |
| categorize_sentiment:  |   0.755   |   0.656   |   0.427   |   0.798   |   0.959   |

| categorize_topic:      |   0.734   |   0.523   |   0.389   |   0.728   |   0.959   |
| categorize_entailment: |   0.837   |   0.753   | **0.544** | **0.802** |   0.970   |





**Задачи:**



- Semantic text similarity (**STS**);

- Paraphrase identification (**PI**);

- Natural language inference (**NLI**);

- Sentiment analysis (**SA**);

- Toxicity identification (**TI**).







# Метрики

Оценки модели на бенчмарке [ruMTEB](https://habr.com/ru/companies/sberdevices/articles/831150/):



|Model Name                         | Metric              | Frida                  | rubert-mini-frida   | multilingual-e5-large-instruct | multilingual-e5-large |

|:----------------------------------|:--------------------|-----------------------:|--------------------:|---------------------:|----------------------:|

|CEDRClassification                 | Accuracy            |       **0.646**        |         0.552       |        0.500         |         0.448         |

|GeoreviewClassification            | Accuracy            |       **0.577**        |         0.464       |        0.559         |         0.497         |

|GeoreviewClusteringP2P             | V-measure           |       **0.783**        |         0.698       |        0.743         |         0.605         |

|HeadlineClassification             | Accuracy            |       **0.890**        |         0.880       |        0.862         |         0.758         |

|InappropriatenessClassification    | Accuracy            |       **0.783**        |         0.698       |        0.655         |         0.616         |

|KinopoiskClassification            | Accuracy            |       **0.705**        |         0.595       |        0.661         |         0.566         |

|RiaNewsRetrieval                   | NDCG@10             |       **0.868**        |         0.721       |        0.824         |         0.807         |

|RuBQReranking                      | MAP@10              |       **0.771**        |         0.711       |        0.717         |         0.756         |

|RuBQRetrieval                      | NDCG@10             |         0.724          |         0.654       |        0.692         |       **0.741**       |

|RuReviewsClassification            | Accuracy            |       **0.751**        |         0.658       |        0.686         |         0.653         |

|RuSTSBenchmarkSTS                  | Pearson correlation |         0.814          |         0.803       |      **0.840**       |         0.831         |

|RuSciBenchGRNTIClassification      | Accuracy            |       **0.699**        |         0.625       |        0.651         |         0.582         |

|RuSciBenchGRNTIClusteringP2P       | V-measure           |       **0.670**        |         0.586       |        0.622         |         0.520         |

|RuSciBenchOECDClassification       | Accuracy            |       **0.546**        |         0.493       |        0.502         |         0.445         |

|RuSciBenchOECDClusteringP2P        | V-measure           |       **0.566**        |         0.507       |        0.528         |         0.450         |

|SensitiveTopicsClassification      | Accuracy            |       **0.398**        |         0.373       |        0.323         |         0.257         |

|TERRaClassification                | Average Precision   |       **0.665**        |         0.606       |        0.639         |         0.584         |



|Model Name                         | Metric              | Frida                  | rubert-mini-frida   | multilingual-e5-large-instruct | multilingual-e5-large |

|:----------------------------------|:--------------------|-----------------------:|--------------------:|----------------------:|---------------------:|

|Classification                     | Accuracy            |       **0.707**        |        0.631        |        0.654          |        0.588         |

|Clustering                         | V-measure           |       **0.673**        |        0.597        |        0.631          |        0.525         |

|MultiLabelClassification           | Accuracy            |       **0.522**        |        0.463        |        0.412          |        0.353         |

|PairClassification                 | Average Precision   |       **0.665**        |        0.606        |        0.639          |        0.584         |

|Reranking                          | MAP@10              |       **0.771**        |        0.711        |        0.717          |        0.756         |

|Retrieval                          | NDCG@10             |       **0.796**        |        0.687        |        0.758          |        0.774         |

|STS                                | Pearson correlation |         0.814          |        0.803        |      **0.840**        |        0.831         |

|Average                            | Average             |       **0.707**        |        0.643        |        0.664          |        0.630         |







## Использование модели с библиотекой `transformers`:



```python

import torch

import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel





def pool(hidden_state, mask, pooling_method="mean"):

    if pooling_method == "mean":
        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)

        d = mask.sum(axis=1, keepdim=True).float()

        return s / d

    elif pooling_method == "cls":

        return hidden_state[:, 0]


inputs = [
    # 

    "paraphrase: В Ярославской области разрешили работу бань, но без посетителей",

    "categorize_entailment: Женщину доставили в больницу, за ее жизнь сейчас борются врачи.",

    "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",

    # 

    "paraphrase: Ярославским баням разрешили работать без посетителей",

    "categorize_entailment: Женщину спасают врачи.",

    "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование."

]


tokenizer = AutoTokenizer.from_pretrained("sergeyzh/rubert-mini-frida")

model = AutoModel.from_pretrained("sergeyzh/rubert-mini-frida")

tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")



with torch.no_grad():
    outputs = model(**tokenized_inputs)

    

embeddings = pool(

    outputs.last_hidden_state, 

    tokenized_inputs["attention_mask"],

    pooling_method="mean"

)


embeddings = F.normalize(embeddings, p=2, dim=1)
sim_scores = embeddings[:3] @ embeddings[3:].T

print(sim_scores.diag().tolist())
# [0.9423348903656006, 0.8306248188018799, 0.7095720767974854]
# [0.9360030293464661, 0.8591322302818298, 0.728583037853241] - FRIDA
```





## Использование с `sentence_transformers` (sentence-transformers>=2.4.0):



```python

from sentence_transformers import SentenceTransformer



# loads model with mean pooling

model = SentenceTransformer("sergeyzh/rubert-mini-frida")



paraphrase = model.encode(["В Ярославской области разрешили работу бань, но без посетителей", "Ярославским баням разрешили работать без посетителей"], prompt="paraphrase: ")

print(paraphrase[0] @ paraphrase[1].T) 

# 0.94233495

# 0.9360032 - FRIDA



categorize_entailment = model.encode(["Женщину доставили в больницу, за ее жизнь сейчас борются врачи.", "Женщину спасают врачи."], prompt="categorize_entailment: ")

print(categorize_entailment[0] @ categorize_entailment[1].T) 

# 0.8306249

# 0.8591322 - FRIDA



query_embedding = model.encode("Сколько программистов нужно, чтобы вкрутить лампочку?", prompt="search_query: ")

document_embedding = model.encode("Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.", prompt="search_document: ")

print(query_embedding @ document_embedding.T) 

# 0.70957196

# 0.7285831 - FRIDA

```