BERTA

Модель для расчетов эмбеддингов предложений на русском и английском языках получена методом дистилляции эмбеддингов ai-forever/FRIDA (размер эмбеддингов - 1536, слоёв - 24) в sergeyzh/LaBSE-ru-turbo (размер эмбеддингов - 768, слоёв - 12). Основной режим использования FRIDA - CLS pooling заменен на mean pooling. Каких-либо других изменений поведения модели не производилось. Дистиляция выполнена в максимально возможном объеме - эмбеддинги русских и английских предложений, работа префиксов.

Размер контекста модели соответствует FRIDA - 512 токенов.

Префиксы

Все префиксы унаследованы от FRIDA. Оптимальный (обеспечивающий средние результаты) префикс для большинства задач - "categorize_entailment: " прописан по умолчанию в config_sentence_transformers.json

Перечень используемых префиксов и их влияние на оценки модели в encodechka:

Префикс STS PI NLI SA TI
- 0,842 0,757 0,463 0,830 0,985
search_query: 0,853 0,767 0,479 0,825 0,987
search_document: 0,831 0,749 0,463 0,817 0,986
paraphrase: 0,847 0,778 0,446 0,825 0,986
categorize: 0,857 0,765 0,501 0,829 0,988
categorize_sentiment: 0,589 0,535 0,417 0,805 0,982
categorize_topic: 0,740 0,521 0,396 0,770 0,982
categorize_entailment: 0,841 0,762 0,571 0,827 0,986

Задачи:

  • Semantic text similarity (STS);
  • Paraphrase identification (PI);
  • Natural language inference (NLI);
  • Sentiment analysis (SA);
  • Toxicity identification (TI).

Метрики

Оценки модели на бенчмарке ruMTEB:

Model Name Metric FRIDA BERTA rubert-mini-frida multilingual-e5-large-instruct multilingual-e5-large
CEDRClassification Accuracy 0.646 0.622 0.552 0.500 0.448
GeoreviewClassification Accuracy 0.577 0.548 0.464 0.559 0.497
GeoreviewClusteringP2P V-measure 0.783 0.738 0.698 0.743 0.605
HeadlineClassification Accuracy 0.890 0.891 0.880 0.862 0.758
InappropriatenessClassification Accuracy 0.783 0.748 0.698 0.655 0.616
KinopoiskClassification Accuracy 0.705 0.678 0.595 0.661 0.566
RiaNewsRetrieval NDCG@10 0.868 0.816 0.721 0.824 0.807
RuBQReranking MAP@10 0.771 0.752 0.711 0.717 0.756
RuBQRetrieval NDCG@10 0.724 0.710 0.654 0.692 0.741
RuReviewsClassification Accuracy 0.751 0.723 0.658 0.686 0.653
RuSTSBenchmarkSTS Pearson correlation 0.814 0.822 0.803 0.840 0.831
RuSciBenchGRNTIClassification Accuracy 0.699 0.690 0.625 0.651 0.582
RuSciBenchGRNTIClusteringP2P V-measure 0.670 0.650 0.586 0.622 0.520
RuSciBenchOECDClassification Accuracy 0.546 0.555 0.493 0.502 0.445
RuSciBenchOECDClusteringP2P V-measure 0.566 0.556 0.507 0.528 0.450
SensitiveTopicsClassification Accuracy 0.398 0.399 0.373 0.323 0.257
TERRaClassification Average Precision 0.665 0.657 0.606 0.639 0.584
Model Name Metric FRIDA BERTA rubert-mini-frida multilingual-e5-large-instruct multilingual-e5-large
Classification Accuracy 0.707 0.698 0.631 0.654 0.588
Clustering V-measure 0.673 0.648 0.597 0.631 0.525
MultiLabelClassification Accuracy 0.522 0.510 0.463 0.412 0.353
PairClassification Average Precision 0.665 0.657 0.606 0.639 0.584
Reranking MAP@10 0.771 0.752 0.711 0.717 0.756
Retrieval NDCG@10 0.796 0.763 0.687 0.758 0.774
STS Pearson correlation 0.814 0.822 0.803 0.840 0.831
Average Average 0.707 0.693 0.643 0.664 0.630

Использование модели с библиотекой transformers:

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel


def pool(hidden_state, mask, pooling_method="mean"):
    if pooling_method == "mean":
        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
        d = mask.sum(axis=1, keepdim=True).float()
        return s / d
    elif pooling_method == "cls":
        return hidden_state[:, 0]

inputs = [
    # 
    "paraphrase: В Ярославской области разрешили работу бань, но без посетителей",
    "categorize_entailment: Женщину доставили в больницу, за ее жизнь сейчас борются врачи.",
    "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
    # 
    "paraphrase: Ярославским баням разрешили работать без посетителей",
    "categorize_entailment: Женщину спасают врачи.",
    "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование."
]

tokenizer = AutoTokenizer.from_pretrained("sergeyzh/BERTA")
model = AutoModel.from_pretrained("sergeyzh/BERTA")

tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**tokenized_inputs)
    
embeddings = pool(
    outputs.last_hidden_state, 
    tokenized_inputs["attention_mask"],
    pooling_method="mean"
)

embeddings = F.normalize(embeddings, p=2, dim=1)
sim_scores = embeddings[:3] @ embeddings[3:].T
print(sim_scores.diag().tolist())
# [0.9530372023582458, 0.866746723651886,  0.7839133143424988]
# [0.9360030293464661, 0.8591322302818298, 0.728583037853241] - FRIDA

Использование с sentence_transformers (sentence-transformers>=2.4.0):

from sentence_transformers import SentenceTransformer

# loads model with mean pooling
model = SentenceTransformer("sergeyzh/BERTA")

paraphrase = model.encode(["В Ярославской области разрешили работу бань, но без посетителей", "Ярославским баням разрешили работать без посетителей"], prompt="paraphrase: ")
print(paraphrase[0] @ paraphrase[1].T) 
# 0.9530372
# 0.9360032 - FRIDA

categorize_entailment = model.encode(["Женщину доставили в больницу, за ее жизнь сейчас борются врачи.", "Женщину спасают врачи."], prompt="categorize_entailment: ")
print(categorize_entailment[0] @ categorize_entailment[1].T) 
# 0.8667469
# 0.8591322 - FRIDA

query_embedding = model.encode("Сколько программистов нужно, чтобы вкрутить лампочку?", prompt="search_query: ")
document_embedding = model.encode("Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.", prompt="search_document: ")
print(query_embedding @ document_embedding.T) 
# 0.7839136
# 0.7285831 - FRIDA
Downloads last month
9
Safetensors
Model size
128M params
Tensor type
F32
·
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.

Model tree for sergeyzh/BERTA

Finetuned
(2)
this model

Datasets used to train sergeyzh/BERTA