---
license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
tags:
- reranker
- cross-encoder
- dual-passage-classifier
- multi-hop-qa
- pytorch
datasets:
- 2wiki
model-index:
- name: Dual Passage Classifier
  results:
  - task:
      type: text-classification
      name: Passage Reranking (Multi-hop QA)
    dataset:
      name: 2wikimultihopQA
      type: 2wiki
      split: validation
    metrics:
    - type: MAP
      value: TODO
      name: Mean Average Precision
---

# Dual Passage Classifier (DPC) for Multi-hop QA 🔎📑

The **Dual Passage Classifier (DPC)** is a *cross-encoder* reranker. Given a **(Question, Passage 1, Passage 2)** triple, it outputs a single score estimating how much the passage pair *jointly* contributes to answering the question:
33 |
+
|
34 |
+
| label | 定義 | Margin label |
|
35 |
+
|-------|------|--------------|
|
36 |
+
| positive | d1 & d2 都必要 | 0 |
|
37 |
+
| neutral | 其中一段必要 |1 |
|
38 |
+
| negative | 都不重要 | 2 |
|
39 |
+
|
The model blends a three-tier **MarginRankingLoss** over these labels and is evaluated on 2wiki against the baseline (`naver/trecdl22-crossencoder-debertav3`). It was trained on the 2wiki dataset.
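
The three margin labels above imply a pairwise ranking objective: a pair with a better (lower-numbered) tier should score higher than a pair with a worse tier. The following is a minimal sketch of what such a blended three-tier loss could look like; the helper name, margin value, and batching scheme are illustrative assumptions, not the released training code:

```python
import torch
import torch.nn as nn

# Margin of 1.0 is an illustrative choice, not the trained configuration.
margin_loss = nn.MarginRankingLoss(margin=1.0)

def three_way_margin_loss(scores, tiers):
    """scores: (B,) model scores; tiers: (B,) ints in {0, 1, 2}, 0 = best."""
    losses = []
    for better in range(2):                 # compare tiers 0<1, 1<2, 0<2
        for worse in range(better + 1, 3):
            s_better = scores[tiers == better]
            s_worse = scores[tiers == worse]
            if len(s_better) == 0 or len(s_worse) == 0:
                continue
            # Compare every better-tier score against every worse-tier score.
            a = s_better.unsqueeze(1).expand(-1, len(s_worse)).reshape(-1)
            b = s_worse.unsqueeze(0).expand(len(s_better), -1).reshape(-1)
            target = torch.ones_like(a)     # target=1: a should rank above b
            losses.append(margin_loss(a, b, target))
    return torch.stack(losses).mean()
```

With well-separated scores the loss is zero; any pair violating the margin contributes a positive penalty.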
---

## Quick-start

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("your-username/dual-passage-classifier")
model = AutoModelForSequenceClassification.from_pretrained("your-username/dual-passage-classifier").eval()

def score(q, d1, d2):
    # Concatenate the question and both passages into one cross-encoder input
    text = f"{q} [SEP] {d1} [SEP] {d2}"
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logit = model(**inputs).logits
    return torch.sigmoid(logit).item()  # 0–1; higher means more positive
```
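
In a multi-hop pipeline, the `score` function can drive pair selection over a candidate passage pool. `rerank_pairs` below is a hypothetical helper (not part of the released code) showing one way to enumerate and rank all unordered passage pairs:

```python
from itertools import combinations

def rerank_pairs(question, passages, score_fn):
    """Score every unordered passage pair and return them sorted best-first.

    score_fn(question, d1, d2) -> float, e.g. the `score` function above.
    """
    scored = [((i, j), score_fn(question, passages[i], passages[j]))
              for i, j in combinations(range(len(passages)), 2)]
    return sorted(scored, key=lambda item: item[1], reverse=True)
```

The top-ranked pair is then the model's best guess at the two passages that jointly answer the question.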