from transformers import AutoModel, AutoTokenizer
from dataset import get_data_for_evaluation
from arguments import get_args
from tqdm import tqdm
import torch
import os


def run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer, max_seq_len=512):
    """Rank each document's chunks against its queries; return the ranked
    chunk indices for every query along with the gold chunk indices."""

    ranked_indices_list = []
    gold_index_list = []
    for doc_id in tqdm(eval_data):
        context_list = documents[doc_id]

        with torch.no_grad():
            # Encode each context chunk, using the [CLS] token embedding
            # as the chunk representation.
            context_embs = []
            for chunk in context_list:
                chunk_ids = tokenizer(chunk, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")
                c_emb = context_encoder(input_ids=chunk_ids.input_ids, attention_mask=chunk_ids.attention_mask)
                c_emb = c_emb.last_hidden_state[:, 0, :]
                context_embs.append(c_emb)
            context_embs = torch.cat(context_embs, dim=0)  # (num_chunks, hidden_dim)

            # Encode each query for this document and record its gold chunk index.
            sample_list = eval_data[doc_id]
            query_embs = []
            for item in sample_list:
                gold_index_list.append(item['gold_idx'])

                query = item['query']
                query_ids = tokenizer(query, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")
                q_emb = query_encoder(input_ids=query_ids.input_ids, attention_mask=query_ids.attention_mask)
                q_emb = q_emb.last_hidden_state[:, 0, :]
                query_embs.append(q_emb)
            query_embs = torch.cat(query_embs, dim=0)  # (num_queries, hidden_dim)

            # Dot-product similarity between every query and every chunk,
            # ranked from most to least similar.
            similarities = query_embs.matmul(context_embs.transpose(0, 1))
            ranked_results = torch.argsort(similarities, dim=-1, descending=True)
            ranked_indices_list.extend(ranked_results.tolist())

    return ranked_indices_list, gold_index_list
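
# The loops above encode one text at a time. Below is a minimal sketch (not
# called anywhere in this script) of how the same [CLS] embeddings could be
# computed in batches instead; `batch_size` is an assumed tuning knob, not an
# existing argument of this script.
def encode_texts_batched(texts, encoder, tokenizer, max_seq_len=512, batch_size=32):
    embs = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            # Pad within the batch so the texts can be encoded in one forward pass.
            batch = tokenizer(texts[i:i + batch_size], max_length=max_seq_len,
                              truncation=True, padding=True, return_tensors="pt").to("cuda")
            out = encoder(input_ids=batch.input_ids, attention_mask=batch.attention_mask)
            embs.append(out.last_hidden_state[:, 0, :])  # [CLS] embeddings
    return torch.cat(embs, dim=0)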

def calculate_recall(ranked_indices_list, gold_index_list, topk):
    """Compute top-k recall: the fraction of queries whose gold chunk index
    appears among the top-k ranked indices."""

    hit = 0
    for ranked_indices, gold_index in zip(ranked_indices_list, gold_index_list):
        if gold_index in ranked_indices[:topk]:
            hit += 1
    recall = hit / len(ranked_indices_list)

    print("top-%d recall score: %.4f" % (topk, recall))
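
# A minimal usage sketch with hypothetical inputs: given ranked lists
# [[2, 0, 1], [1, 2, 0]] and gold indices [0, 1], the first query misses at
# top-1 while the second hits, so the call below prints
# "top-1 recall score: 0.5000":
#   calculate_recall([[2, 0, 1], [1, 2, 0]], [0, 1], topk=1)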

def main():
    args = get_args()

    # Both encoders share one tokenizer, loaded from the query encoder path.
    tokenizer = AutoTokenizer.from_pretrained(args.query_encoder_path)

    query_encoder = AutoModel.from_pretrained(args.query_encoder_path)
    context_encoder = AutoModel.from_pretrained(args.context_encoder_path)
    query_encoder.to("cuda")
    query_encoder.eval()
    context_encoder.to("cuda")
    context_encoder.eval()

    # Resolve the evaluation data and document paths for the chosen dataset.
    if args.eval_dataset == "doc2dial":
        input_datapath = os.path.join(args.data_folder, args.doc2dial_datapath)
        input_docpath = os.path.join(args.data_folder, args.doc2dial_docpath)
    elif args.eval_dataset == "quac":
        input_datapath = os.path.join(args.data_folder, args.quac_datapath)
        input_docpath = os.path.join(args.data_folder, args.quac_docpath)
    elif args.eval_dataset == "qrecc":
        input_datapath = os.path.join(args.data_folder, args.qrecc_datapath)
        input_docpath = os.path.join(args.data_folder, args.qrecc_docpath)
    elif args.eval_dataset in ("topiocqa", "inscit"):
        raise Exception("We have prepared the function to get queries, but a Wikipedia corpus needs to be downloaded")
    else:
        raise Exception("Please input a valid eval_dataset name!")

    eval_data, documents = get_data_for_evaluation(input_datapath, input_docpath, args.eval_dataset)

    ranked_indices_list, gold_index_list = run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer)
    print("number of total test samples: %d" % len(ranked_indices_list))

    print("evaluating on %s" % args.eval_dataset)
    topk_list = [1, 5, 20]
    for topk in topk_list:
        calculate_recall(ranked_indices_list, gold_index_list, topk=topk)


if __name__ == "__main__":
    main()
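
# Example invocation (a sketch only; the script name is illustrative and the
# flag spellings assume arguments.get_args exposes options matching the
# attribute names used above):
#   python evaluate_retriever.py \
#       --query_encoder_path <model_name_or_dir> \
#       --context_encoder_path <model_name_or_dir> \
#       --data_folder <data_root> \
#       --eval_dataset doc2dial \
#       --doc2dial_datapath <relative_data_path> \
#       --doc2dial_docpath <relative_doc_path>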