Purpose: pythera/mbert-rerank-base is a reranking model that takes a search query [1] and a passage [2] returned by a retrieval model and scores how well the passage matches the query.
Languages: English, Vietnamese
import torch
from typing import Tuple
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load model and tokenizer from the Hugging Face Hub
model = AutoModelForSequenceClassification.from_pretrained('pythera/mbert-rerank-base')
tokenizer = AutoTokenizer.from_pretrained('pythera/mbert-rerank-base')
model.eval()

# Encode a (query, passage) pair as a single input sequence
# max_length defaults to 512 (the usual BERT limit, assumed here because the original call omitted it)
def create_pairs(pair, max_length=512):
    # Tokenize both texts together; if the pair is too long, only the query (first text) is truncated
    return tokenizer(pair[0], pair[1], truncation='only_first', max_length=max_length,
                     return_token_type_ids=True, return_tensors='pt')

# Score a (query, passage) pair
def rerank(pair: Tuple[str, str]):
    assert isinstance(pair, tuple)
    # Tokenize the pair
    encoded_pair = create_pairs(pair)
    # Compute the relevance logits without tracking gradients
    with torch.no_grad():
        score = model(**encoded_pair, return_dict=True).logits
    return score

# Prepare a (query, passage) pair
pair = ('I come from Vietnam', 'I am from Vietnam')

# Rerank the pair
score = rerank(pair)
print('Pair score: ', score)
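
In a retrieval pipeline the reranker is usually applied to several candidate passages for one query, and the candidates are then sorted by score. The following is a minimal sketch of that step, not part of the model's API: `rank_passages` and `overlap_score` are hypothetical names, and `overlap_score` is a toy word-overlap scorer used only so the example runs without downloading the model. To use the real model, one could pass a wrapper around `rerank` as `score_fn`, assuming it is reduced to a single float per pair (the shape of the logits is not stated in this card).

```python
from typing import Callable, List, Tuple


def rank_passages(
    query: str,
    passages: List[str],
    score_fn: Callable[[str, str], float],
) -> List[Tuple[str, float]]:
    """Score every passage against the query and sort from best to worst."""
    scored = [(passage, float(score_fn(query, passage))) for passage in passages]
    return sorted(scored, key=lambda item: item[1], reverse=True)


def overlap_score(query: str, passage: str) -> float:
    """Toy stand-in scorer (hypothetical): number of lowercase words shared by both texts."""
    return float(len(set(query.lower().split()) & set(passage.lower().split())))


candidates = ["The weather is nice", "I am from Vietnam"]
ranked = rank_passages("I come from Vietnam", candidates, overlap_score)
for passage, score in ranked:
    print(f"{score:.1f}  {passage}")
```

Keeping the scorer as a parameter means the sorting logic can be tested independently of the model.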
We evaluate our model on the mMARCO (vi) passage ranking task and compare it with several other rerankers:
Model | Training Dataset | MRR@10 |
---|---|---|
mMiniLM-rerankers | MS MARCO | 24.7 |
mT5-rerankers | MS MARCO | 25.6 |
mbert-rerankers (ours) | MS MARCO | 35.0 |
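
MRR@10 is the mean, over all queries, of the reciprocal rank of the first relevant passage among the top 10 results, counting 0 when none appears. The sketch below shows one way to compute it; `mrr_at_k` and the query and passage ids are hypothetical, and this is not the evaluation script used for the table. The values in the table appear to be this quantity multiplied by 100, though the card does not say so.

```python
from typing import Dict, List, Set


def mrr_at_k(
    rankings: Dict[str, List[str]],
    relevant: Dict[str, Set[str]],
    k: int = 10,
) -> float:
    """Mean reciprocal rank of the first relevant passage within the top k."""
    if not rankings:
        return 0.0
    total = 0.0
    for query_id, ranked_ids in rankings.items():
        gold = relevant.get(query_id, set())
        for rank, passage_id in enumerate(ranked_ids[:k], start=1):
            if passage_id in gold:
                total += 1.0 / rank
                break  # only the first relevant hit counts
    return total / len(rankings)


# Illustrative data: q1 has its relevant passage at rank 2, q2 has none retrieved.
rankings = {"q1": ["p3", "p1", "p7"], "q2": ["p9", "p4"]}
relevant = {"q1": {"p1"}, "q2": {"p5"}}
mrr = mrr_at_k(rankings, relevant)
print(f"MRR@10 = {mrr:.3f}")
```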