File size: 4,198 Bytes
14ca076 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from transformers import AutoModel, AutoTokenizer
from dataset import get_data_for_evaluation
from arguments import get_args
from tqdm import tqdm
import torch
import os
def _encode_cls(texts, encoder, tokenizer, max_seq_len, device):
    """Encode each text on its own and stack the first-token embeddings.

    Texts are tokenized one at a time (no padding), exactly like the original
    per-item loops. Returns a tensor with one row per text, i.e.
    (len(texts), hidden_dim). ``texts`` must be non-empty because torch.cat
    rejects an empty list; callers check this first.
    """
    embs = []
    for text in texts:
        ids = tokenizer(text, max_length=max_seq_len, truncation=True, return_tensors="pt").to(device)
        out = encoder(input_ids=ids.input_ids, attention_mask=ids.attention_mask)
        # The first token's final hidden state is used as the text embedding.
        embs.append(out.last_hidden_state[:, 0, :])
    return torch.cat(embs, dim=0)


def run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer, max_seq_len=512, device="cuda"):
    """Rank each document's chunks for every query on that document.

    Similarity is the dot product between the query embedding and each chunk
    embedding (first-token hidden states of the respective encoders).

    Args:
        eval_data: mapping doc_id -> list of sample dicts, each with a 'query'
            string and a 'gold_idx' that is compared against chunk indices.
        documents: mapping doc_id -> list of chunk strings.
        query_encoder: model producing ``last_hidden_state`` for queries.
        context_encoder: model producing ``last_hidden_state`` for chunks.
        tokenizer: tokenizer used for both queries and chunks.
        max_seq_len: truncation length applied to queries and chunks.
        device: device the tokenized inputs are moved to; presumably must
            match the encoders' device (default "cuda", as before).

    Returns:
        (ranked_indices_list, gold_index_list), one aligned entry per query.
        Each ranked entry lists chunk indices from most to least similar. It is
        an empty list when the document has no chunks, so such queries count
        as misses instead of crashing torch.cat.
    """
    ranked_indices_list = []
    gold_index_list = []
    for doc_id in tqdm(eval_data):
        sample_list = eval_data[doc_id]
        if not sample_list:
            # Nothing to score; also avoids torch.cat([]) on the query side.
            continue
        context_list = documents[doc_id]
        gold_index_list.extend(item['gold_idx'] for item in sample_list)
        if not context_list:
            # No chunks to rank: keep lists aligned with an empty ranking per query.
            ranked_indices_list.extend([] for _ in sample_list)
            continue

        with torch.no_grad():
            context_embs = _encode_cls(context_list, context_encoder, tokenizer, max_seq_len, device)  # (num_chunk, hidden_dim)
            queries = [item['query'] for item in sample_list]
            query_embs = _encode_cls(queries, query_encoder, tokenizer, max_seq_len, device)  # (num_query, hidden_dim)
            similarities = query_embs.matmul(context_embs.transpose(0, 1))  # (num_query, num_chunk)
            ranked_results = torch.argsort(similarities, dim=-1, descending=True)  # (num_query, num_chunk)
        ranked_indices_list.extend(ranked_results.tolist())
    return ranked_indices_list, gold_index_list
def calculate_recall(ranked_indices_list, gold_index_list, topk):
    """Print and return top-k recall.

    Recall is the fraction of queries whose gold index appears among the first
    ``topk`` entries of that query's ranking.

    Args:
        ranked_indices_list: per-query rankings, most similar first.
        gold_index_list: per-query gold indices, aligned with the rankings.
        topk: number of leading ranked entries to inspect.

    Returns:
        The recall as a float. It is 0.0 when there are no queries; the
        original raised ZeroDivisionError in that case.
    """
    num_queries = len(ranked_indices_list)
    hit = sum(
        1
        for ranked_indices, gold_index in zip(ranked_indices_list, gold_index_list)
        if gold_index in ranked_indices[:topk]
    )
    recall = hit / num_queries if num_queries else 0.0
    print("top-%d recall score: %.4f" % (topk, recall))
    return recall
def main():
    """Evaluate retrieval recall on the dataset named by ``args.eval_dataset``.

    The dataset name is validated before any model is loaded, so a bad name
    fails fast. The tokenizer and both encoders are then loaded from the
    configured paths and moved to CUDA in eval mode. Finally, retrieval runs
    and top-1/5/20 recall is printed.

    Raises:
        NotImplementedError: for "topiocqa" / "inscit", which need a separately
            downloaded Wikipedia corpus.
        ValueError: for any other unrecognised dataset name.
    """
    args = get_args()

    ## resolve evaluation data paths (before loading models, to fail fast)
    dataset_paths = {
        "doc2dial": (args.doc2dial_datapath, args.doc2dial_docpath) if args.eval_dataset == "doc2dial" else None,
        "quac": (args.quac_datapath, args.quac_docpath) if args.eval_dataset == "quac" else None,
        "qrecc": (args.qrecc_datapath, args.qrecc_docpath) if args.eval_dataset == "qrecc" else None,
    }
    if args.eval_dataset in ("topiocqa", "inscit"):
        raise NotImplementedError("We have prepare the function to get queries, but a wikipedia corpus needs to be downloaded")
    if args.eval_dataset not in dataset_paths:
        raise ValueError("Please input a correct eval_dataset name!")
    datapath, docpath = dataset_paths[args.eval_dataset]
    input_datapath = os.path.join(args.data_folder, datapath)
    input_docpath = os.path.join(args.data_folder, docpath)

    ## get tokenizer (shared by query and context encoding)
    tokenizer = AutoTokenizer.from_pretrained(args.query_encoder_path)
    ## get retriever models
    query_encoder = AutoModel.from_pretrained(args.query_encoder_path)
    context_encoder = AutoModel.from_pretrained(args.context_encoder_path)
    for encoder in (query_encoder, context_encoder):
        encoder.to("cuda")
        encoder.eval()  # inference mode (e.g. disables dropout)

    ## get evaluation data
    eval_data, documents = get_data_for_evaluation(input_datapath, input_docpath, args.eval_dataset)

    ## run retrieval
    ranked_indices_list, gold_index_list = run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer)
    print("number of the total test samples: %d" % len(ranked_indices_list))

    ## calculate recall scores
    print("evaluating on %s" % args.eval_dataset)
    for topk in (1, 5, 20):
        calculate_recall(ranked_indices_list, gold_index_list, topk=topk)


if __name__ == "__main__":
    main()
|