---
datasets:
  - unicamp-dl/mmarco
language:
  - en
  - vi
pipeline_tag: text-classification
---

Purpose: pythera/mbert-rerank-base is a reranking model. It takes a search query [1] and a passage [2] returned by a retrieval model and predicts whether the passage matches the query.
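A reranker usually runs as the second stage of a retrieve-then-rerank pipeline. It scores each (query, passage) pair produced by the retriever and reorders the candidates by that score. The sketch below shows only that reordering step, under stated assumptions. `rerank_passages`, `score_fn` and `overlap_score` are hypothetical names. `overlap_score` is a toy word-overlap scorer used only so the example runs offline. In practice, `score_fn` would run this model on the pair and return a single float derived from its logits.

```python
from typing import Callable, List, Sequence, Tuple


def rerank_passages(
    query: str,
    passages: Sequence[str],
    score_fn: Callable[[str, str], float],
    top_k: int = 10,
) -> List[Tuple[str, float]]:
    """Score every (query, passage) pair and return the top_k passages, best first."""
    scored = [(passage, score_fn(query, passage)) for passage in passages]
    # Python's sort is stable: passages with equal scores keep their retrieval order
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]


# Hypothetical toy scorer: counts shared lowercase words (stands in for the model)
def overlap_score(query: str, passage: str) -> float:
    return float(len(set(query.lower().split()) & set(passage.lower().split())))


candidates = ["Hanoi is the capital of Vietnam", "Paris is in France", "I am from Vietnam"]
print(rerank_passages("I come from Vietnam", candidates, overlap_score, top_k=2))
# → [('I am from Vietnam', 3.0), ('Hanoi is the capital of Vietnam', 1.0)]
```

Keeping the scorer as a parameter separates the reordering logic from how a pair is scored. This makes it easy to swap the toy scorer for real model calls.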

Supported languages

Languages: English, Vietnamese

Usage

```python
import torch
from typing import Tuple
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load model and tokenizer from the Hugging Face Hub
model = AutoModelForSequenceClassification.from_pretrained('pythera/mbert-rerank-base')
tokenizer = AutoTokenizer.from_pretrained('pythera/mbert-rerank-base')
model.eval()

# Encode a (query, passage) pair as [CLS] query [SEP] passage [SEP]
def create_pairs(pair: Tuple[str, str], max_length: int = 512):
    query, passage = pair
    # Keep the whole query; truncate only the passage if the pair is too long
    return tokenizer(query, passage, truncation='only_second', max_length=max_length,
                     return_token_type_ids=True, return_tensors='pt')

# Score a (query, passage) pair
def rerank(pair: Tuple[str, str]) -> torch.Tensor:
    assert isinstance(pair, tuple) and len(pair) == 2
    # Tokenize the pair
    encoded_pair = create_pairs(pair)

    # Compute relevance logits (no gradients needed at inference time)
    with torch.no_grad():
        score = model(**encoded_pair, return_dict=True).logits

    return score

# Prepare a (query, passage) pair to rerank
pair = ('I come from Vietnam', 'I am from Vietnam')

# Score the pair
score = rerank(pair)
print('Pair score: ', score)
```
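`score` holds raw logits of shape `(1, num_labels)`. This card does not state how many labels the classification head has or which label means "relevant". You can inspect both via `model.config.num_labels` and `model.config.id2label`. The helper below is a hypothetical sketch for turning one row of logits into a probability in [0, 1]. It assumes that a 1-logit head is a binary relevance logit and that index 1 is the "relevant" class of a 2-logit head. Verify both assumptions against the model config before relying on it.

```python
import math
from typing import Sequence


def logits_to_relevance(logits: Sequence[float]) -> float:
    """Map one row of reranker logits to a relevance probability in [0, 1].

    Assumptions (not stated in the model card): a single logit is a binary
    relevance logit; with two logits, index 1 is the "relevant" class.
    """
    if len(logits) == 1:
        x = logits[0]
        # Numerically stable sigmoid (avoids overflow for large |x|)
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        e = math.exp(x)
        return e / (1.0 + e)
    if len(logits) == 2:
        # Numerically stable softmax; keep the probability of class 1
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]
        return exps[1] / sum(exps)
    raise ValueError(f"expected 1 or 2 logits, got {len(logits)}")
```

With the model output, you would call it as `logits_to_relevance(score[0].tolist())`. For ranking alone, comparing raw logits gives the same order, so this conversion is only needed when you want a calibrated-looking score or a threshold.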

Evaluation

We evaluate the model on the mMARCO (vi) passage ranking task and compare it with other multilingual rerankers:

| Model | Training data | MRR@10 |
|---|---|---|
| mMiniLM-rerankers | MS MARCO | 24.7 |
| mT5-rerankers | MS MARCO | 25.6 |
| mbert-rerankers (ours) | MS MARCO | 35.0 |
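MRR@10 (Mean Reciprocal Rank at cutoff 10) averages, over all queries, 1/rank of the first relevant passage within the top 10 results. A query with no relevant passage in its top 10 contributes 0. The metric lies in [0, 1], so the table reports it as a percentage. The following is a minimal sketch of how MRR@k is commonly computed. It assumes one ranked list of passage ids per query and a set of relevant ids per query. `mrr_at_k` is a hypothetical helper and is not this project's evaluation script.

```python
from typing import Sequence, Set


def mrr_at_k(rankings: Sequence[Sequence[str]], relevant: Sequence[Set[str]], k: int = 10) -> float:
    """Mean Reciprocal Rank@k over all queries (0 if no query is given)."""
    if len(rankings) != len(relevant):
        raise ValueError("need one relevance set per query")
    if not rankings:
        return 0.0
    total = 0.0
    for ranked_ids, relevant_ids in zip(rankings, relevant):
        # Only the first relevant hit within the top k counts
        for rank, pid in enumerate(ranked_ids[:k], start=1):
            if pid in relevant_ids:
                total += 1.0 / rank
                break
    return total / len(rankings)


# Three queries: first relevant hit at rank 2, at rank 1, and none
print(mrr_at_k([["p3", "p1", "p7"], ["p2", "p9"], ["p5"]], [{"p1"}, {"p2"}, {"p8"}]))
# → 0.5
```

Multiplying the result by 100 gives numbers on the same scale as the table.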