license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
tags:
  - reranker
  - cross-encoder
  - dual-passage-classifier
  - multi-hop-qa
  - pytorch
datasets:
  - musique
model-index:
  - name: Dual Passage Classifier
    results:
      - task:
          type: text-classification
          name: Passage Reranking (Multi-hop QA)
        dataset:
          name: MuSiQue-Full
          type: musique
          split: validation
        metrics:
          - type: MAP
            value: TODO
            name: Mean Average Precision

Dual Passage Classifier (DPC) for Multi-hop QA 🔎📑

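Dual Passage Classifier (DPC) is a cross-encoder reranker. It takes **(Question, Passage 1, Passage 2)** as input and outputs a single scalar score that estimates how much the two passages *jointly* contribute to answering the question: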

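| Label | Definition | Margin label |
|---|---|---|
| positive | both d1 and d2 are necessary | 0 |
| neutral | only one of the two passages is necessary | 1 |
| negative | neither passage is important | 2 |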
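The label scheme in the table can be sketched as a small function. This is a minimal sketch that assumes passage necessity comes from MuSiQue's gold supporting-paragraph annotations; `margin_label` and the passage ids are hypothetical names used only for illustration.

```python
def margin_label(d1_needed: bool, d2_needed: bool) -> int:
    """Map passage necessity to the DPC margin label (0 = positive, 1 = neutral, 2 = negative)."""
    if d1_needed and d2_needed:
        return 0  # positive: both passages are needed
    if d1_needed or d2_needed:
        return 1  # neutral: only one of the two passages is needed
    return 2      # negative: neither passage matters


# Hypothetical example: passage ids marked as supporting for one question.
supporting = {"p1", "p3"}
pairs = [("p1", "p3"), ("p1", "p2"), ("p2", "p4")]
labels = [margin_label(a in supporting, b in supporting) for a, b in pairs]
print(labels)  # [0, 1, 2]
```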

This model is trained with a mixed three-tier MarginRankingLoss over these labels, with naver/trecdl22-crossencoder-debertav3 as the baseline, on the MuSiQue dataset.
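One way to read a three-tier margin ranking objective: every pair with a lower margin label should score higher than every pair with a higher label. The sketch below reimplements the per-pair term of `torch.nn.MarginRankingLoss`, `max(0, -y * (x1 - x2) + margin)`, in plain Python so it runs without PyTorch. The single margin of 1.0 and the equal weighting of the three tiers are assumptions; the card does not state the margins or how the tiers are mixed.

```python
def margin_ranking(x1: float, x2: float, y: int = 1, margin: float = 1.0) -> float:
    # Per-pair term of torch.nn.MarginRankingLoss: max(0, -y * (x1 - x2) + margin)
    return max(0.0, -y * (x1 - x2) + margin)


def three_tier_loss(scores, labels, margin=1.0):
    """Mean margin loss over all ordered pairs whose labels differ.

    A pair with a lower margin label (0 = positive) should outscore one with a
    higher label (2 = negative) by at least `margin`. One shared margin and
    equal tier weights are assumptions, not the card's documented recipe.
    """
    terms = []
    for s_i, l_i in zip(scores, labels):
        for s_j, l_j in zip(scores, labels):
            if l_i < l_j:  # i should rank above j, so y = 1
                terms.append(margin_ranking(s_i, s_j, y=1, margin=margin))
    return sum(terms) / len(terms) if terms else 0.0


# Toy scores for one positive, one neutral and one negative pair.
print(three_tier_loss([2.0, 0.5, -1.0], [0, 1, 2]))  # 0.0: every gap already >= margin
print(three_tier_loss([0.2, 0.0, 0.1], [0, 1, 2]))   # about 0.933: margins violated
```

Comparing every tier against every other tier (positive vs neutral, neutral vs negative, positive vs negative) is what makes the objective "three-tier" in this sketch; a real training loop would compute the same terms on batched tensors with `torch.nn.MarginRankingLoss`.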

Quick-start

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

MODEL_ID = "QQhahaha/musique-deberta-v3-large-MLP-2-Marginloss-ratio7"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Assumes the checkpoint loads as a single-output (num_labels=1) classification head.
# If transformers warns that classifier weights were newly initialized, it does not.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).eval()

def score(q, d1, d2):
    # Question and both passages joined with [SEP], as in the card's input format
    text = f"{q} [SEP] {d1} [SEP] {d2}"
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logit = model(**inputs).logits.squeeze()  # one raw score for (q, d1, d2)
    return torch.sigmoid(logit).item()  # in (0, 1); higher means a more useful pair
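The `score` function rates one passage pair at a time, so reranking a retrieved candidate set means scoring every pair and sorting. The sketch below does that and computes average precision for a single question (MAP, the metric listed in the metadata, is its mean over questions). `toy_score` is a hypothetical word-overlap stand-in so the example runs without downloading the model; pass `score` instead to use DPC. Treating only positive pairs as relevant is an assumption, since the card does not say how its MAP is computed.

```python
from itertools import combinations
from typing import Callable, List, Tuple


def rerank_pairs(question: str, passages: List[str],
                 score_fn: Callable[[str, str, str], float]) -> List[Tuple[float, int, int]]:
    """Score every unordered passage pair; return (score, i, j) sorted best-first."""
    scored = [(score_fn(question, passages[i], passages[j]), i, j)
              for i, j in combinations(range(len(passages)), 2)]
    return sorted(scored, key=lambda t: t[0], reverse=True)


def average_precision(relevant: List[bool]) -> float:
    """AP of one fully ranked list; assumes every relevant item appears in it."""
    hits, total = 0, 0.0
    for rank, is_rel in enumerate(relevant, start=1):
        if is_rel:
            hits += 1
            total += hits / rank  # precision at this relevant position
    return total / hits if hits else 0.0


def toy_score(q: str, d1: str, d2: str) -> float:
    # Hypothetical stand-in for score(): share of question words covered by the pair
    q_words = set(q.lower().split())
    covered = q_words & set(f"{d1} {d2}".lower().split())
    return len(covered) / len(q_words)


passages = ["jane doe is the acme founder",
            "jane doe birthplace is oslo",
            "oslo is a city"]
supporting = {0, 1}  # hypothetical gold supporting passages
ranked = rerank_pairs("birthplace of acme founder", passages, toy_score)
print(ranked)  # [(0.75, 0, 1), (0.5, 0, 2), (0.25, 1, 2)]
relevant = [i in supporting and j in supporting for _, i, j in ranked]
print(average_precision(relevant))  # 1.0
```

Scoring all pairs grows quadratically with the candidate count (n passages give n(n-1)/2 calls to `score`), so in practice it is likely worth reranking only a short first-stage list.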