---
license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
tags:
- reranker
- cross-encoder
- dual-passage-classifier
- multi-hop-qa
- pytorch
datasets:
- musique
model-index:
- name: Dual Passage Classifier
  results:
  - task:
      type: text-classification
      name: Passage Reranking (Multi-hop QA)
    dataset:
      name: MuSiQue-Full
      type: musique
      split: validation
    metrics:
    - type: MAP
      value: TODO
      name: Mean Average Precision
---
# Dual Passage Classifier (DPC) for Multi-hop QA 🔎📑
Dual Passage Classifier (DPC) is a cross-encoder reranker. Given **(Question, Passage 1, Passage 2)** as input, it outputs a single scalar score that measures how much the passage pair jointly contributes to answering the question:
| Label | Definition | Margin label |
|---|---|---|
| positive | both d1 and d2 are required | 0 |
| neutral | only one of the two passages is required | 1 |
| negative | neither passage is relevant | 2 |
The model mixes a three-way MarginRankingLoss into training and uses the baseline naver/trecdl22-crossencoder-debertav3 as its reference on MuSiQue. The released checkpoint is trained on the MuSiQue dataset.
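
The exact loss configuration is not spelled out in this card; below is a minimal sketch of one way to realise a three-way margin ranking objective over the labels above, assuming one scalar score per (question, d1, d2) triple and a hypothetical margin of 1.0 (the real margin and term weighting are not documented here).

```python
import torch
import torch.nn as nn

# Hypothetical sketch: the margin value and loss weighting are assumptions.
margin_loss = nn.MarginRankingLoss(margin=1.0)

def three_way_margin_loss(pos_scores, neu_scores, neg_scores):
    """Encourage score(positive) > score(neutral) > score(negative)."""
    ones = torch.ones_like(pos_scores)
    loss_pos_neu = margin_loss(pos_scores, neu_scores, ones)  # positive above neutral
    loss_neu_neg = margin_loss(neu_scores, neg_scores, ones)  # neutral above negative
    loss_pos_neg = margin_loss(pos_scores, neg_scores, ones)  # positive above negative
    return (loss_pos_neu + loss_neu_neg + loss_pos_neg) / 3
```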
## Quick-start
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_id = "QQhahaha/musique-deberta-v3-large-MLP-2-Marginloss-ratio7"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Loading as a sequence-classification model assumes the checkpoint ships its
# scoring head; if the repo stores a custom MLP head separately, it would have
# to be applied to the [CLS] hidden state manually.
model = AutoModelForSequenceClassification.from_pretrained(model_id).eval()

def score(q, d1, d2):
    """Score how much passages d1 and d2 jointly contribute to answering q."""
    text = f"{q} [SEP] {d1} [SEP] {d2}"
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logit = model(**inputs).logits
    return torch.sigmoid(logit).item()  # 0~1, higher means more positive
```
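
The scorer can then be used to rerank candidate passage pairs for a question; the example below is a toy illustration (the question and passages are made up, not from MuSiQue).

```python
from itertools import combinations

# Toy usage: score every passage pair and keep the highest-scoring one.
question = "Who directed the film that won Best Picture in 1998?"
passages = [
    "Titanic won the Academy Award for Best Picture at the 1998 ceremony.",
    "Titanic was directed by James Cameron.",
    "The Eiffel Tower is located in Paris.",
]

ranked = sorted(
    combinations(passages, 2),
    key=lambda pair: score(question, pair[0], pair[1]),
    reverse=True,
)
best_d1, best_d2 = ranked[0]
print(best_d1, "|", best_d2)
```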