---
license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
tags:
- reranker
- cross-encoder
- dual-passage-classifier
- multi-hop-qa
- pytorch
datasets:
- musique
model-index:
- name: Dual Passage Classifier
  results:
  - task:
      type: text-classification
      name: Passage Reranking (Multi-hop QA)
    dataset:
      name: MuSiQue-Full
      type: musique
      split: validation
    metrics:
    - type: MAP
      value: TODO
      name: Mean Average Precision
---

# Dual Passage Classifier (DPC) for Multi-hop QA 🔎📑

**Dual Passage Classifier (DPC)** is a *cross-encoder* reranker. It takes **(Question, Passage 1, Passage 2)**, written below as (q, d1, d2), as input and outputs a one-dimensional score indicating the pair's "joint contribution" to answering the question:

| Label | Definition | Margin label |
|-------|------------|--------------|
| positive | Both d1 and d2 are necessary | 0 |
| neutral | Only one of the two passages is necessary | 1 |
| negative | Neither passage is important | 2 |
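
The label definitions in the table can be written as a small function. This is a minimal sketch, assuming each passage in a pair carries a boolean "is necessary" flag; the helper `margin_label` and its flag arguments are hypothetical and not part of the released model or dataset.

```python
def margin_label(d1_needed: bool, d2_needed: bool) -> tuple[str, int]:
    """Map passage-necessity flags to (label, margin label) following the table.

    Hypothetical helper: the boolean flags are assumed inputs for illustration.
    """
    n_needed = int(d1_needed) + int(d2_needed)
    if n_needed == 2:
        return "positive", 0  # both d1 and d2 are necessary
    if n_needed == 1:
        return "neutral", 1   # only one of the two passages is necessary
    return "negative", 2      # neither passage is important


print(margin_label(True, True))
print(margin_label(True, False))
print(margin_label(False, False))
```

The margin label grows as the pair becomes less useful, so it reads as an ordinal rank (0 is best) rather than as a score.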

This model is trained on the MuSiQue dataset with a loss that mixes in a **three-tier MarginRankingLoss**, using the baseline cross-encoder `naver/trecdl22-crossencoder-debertav3` as its starting point.
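
The card does not spell out how the three tiers are paired or what the margin loss is mixed with. A minimal sketch, assuming PyTorch's `nn.MarginRankingLoss` is applied to every pair of examples whose margin labels differ, so that scores are pushed toward positive > neutral > negative. The all-pairs scheme, the default margin of `0.5`, and the helper name `three_tier_margin_loss` are assumptions for illustration, not the authors' training code.

```python
import torch
import torch.nn as nn


def three_tier_margin_loss(scores: torch.Tensor, margin_labels: torch.Tensor,
                           margin: float = 0.5) -> torch.Tensor:
    """Pairwise margin ranking loss over three ordinal tiers (hypothetical sketch).

    scores:        (N,) model scores for N (q, d1, d2) examples
    margin_labels: (N,) ints in {0, 1, 2}; 0 = positive, 1 = neutral, 2 = negative
    """
    loss_fn = nn.MarginRankingLoss(margin=margin)
    # better[i, j] is True when example i sits in a better tier than example j
    better = margin_labels.unsqueeze(1) < margin_labels.unsqueeze(0)  # (N, N)
    i_idx, j_idx = better.nonzero(as_tuple=True)
    if i_idx.numel() == 0:
        # No cross-tier pairs in this batch: return a zero that keeps the graph
        return scores.sum() * 0.0
    # target = 1 asks scores[i] to exceed scores[j] by at least `margin`
    target = torch.ones_like(scores[i_idx])
    return loss_fn(scores[i_idx], scores[j_idx], target)


labels = torch.tensor([0, 1, 2])  # positive, neutral, negative
ordered = torch.tensor([2.0, 0.5, -1.0])
print(three_tier_margin_loss(ordered, labels))  # every gap exceeds the margin, so the loss is 0
swapped = torch.tensor([0.0, 0.5, -1.0])        # positive scored below neutral
print(three_tier_margin_loss(swapped, labels))
```

In actual training this term would be combined with another objective; the weighting and the partner loss are not specified here.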

---

## Quick-start

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

MODEL_ID = "QQhahaha/musique-deberta-v3-large-MLP-2-Marginloss-ratio7"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Assumes the checkpoint loads as a sequence-classification model with one output logit
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).eval()

def score(q, d1, d2):
    # Join the question and both passages with [SEP], the DeBERTa-v3 separator token
    text = f"{q} [SEP] {d1} [SEP] {d2}"
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logit = model(**inputs).logits.squeeze()  # single relevance logit
    return torch.sigmoid(logit).item()  # between 0 and 1; higher means more positive