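"""Evaluate a dual-encoder retriever on conversational QA benchmarks.

For each document, every chunk is embedded with the context encoder and every
query with the query encoder; chunks are ranked by dot-product similarity, and
top-k recall against the gold chunk index is reported. Dataset paths and the
encoder checkpoints are supplied via arguments.get_args().
"""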

from transformers import AutoModel, AutoTokenizer
from dataset import get_data_for_evaluation
from arguments import get_args
from tqdm import tqdm
import torch
import os


def run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer, max_seq_len=512):
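    """Rank document chunks for every query with the dual encoders.

    Returns:
        ranked_indices_list: for each query, chunk indices sorted by descending similarity.
        gold_index_list: the gold chunk index for each query, in the same order.
    """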
    
    ranked_indices_list = []
    gold_index_list = []
    for doc_id in tqdm(eval_data):
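        # documents[doc_id] holds the chunk list for this document;
        # eval_data[doc_id] holds the queries (with gold chunk indices) asked over it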
        context_list = documents[doc_id]
        
        with torch.no_grad():
            # get chunk embeddings
            context_embs = []
            for chunk in context_list:
                chunk_ids = tokenizer(chunk, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")

                c_emb = context_encoder(input_ids=chunk_ids.input_ids, attention_mask=chunk_ids.attention_mask)
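                # use the first-token ([CLS]) embedding as the chunk representation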
                c_emb = c_emb.last_hidden_state[:, 0, :]
                context_embs.append(c_emb)
            context_embs = torch.cat(context_embs, dim=0)   # (num_chunk, hidden_dim)
    
            sample_list = eval_data[doc_id]
            query_embs = []
            for item in sample_list:
                gold_idx = item['gold_idx']
                gold_index_list.append(gold_idx)

                query = item['query']
                query_ids = tokenizer(query, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")
                q_emb = query_encoder(input_ids=query_ids.input_ids, attention_mask=query_ids.attention_mask)
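                # likewise, use the first-token embedding as the query representation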
                q_emb = q_emb.last_hidden_state[:, 0, :]
                query_embs.append(q_emb)
            query_embs = torch.cat(query_embs, dim=0)   # (num_query, hidden_dim)

            similarities = query_embs.matmul(context_embs.transpose(0,1))     # (num_query, num_chunk)
            ranked_results = torch.argsort(similarities, dim=-1, descending=True)   # (num_query, num_chunk)
            ranked_indices_list.extend(ranked_results.tolist())

    return ranked_indices_list, gold_index_list


def calculate_recall(ranked_indices_list, gold_index_list, topk):
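    """Report top-k recall: the fraction of queries whose gold chunk appears among the top-k ranked chunks."""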
    hit = 0
    for ranked_indices, gold_index in zip(ranked_indices_list, gold_index_list):
        if gold_index in ranked_indices[:topk]:
            hit += 1
    recall = hit / len(ranked_indices_list)

    print("top-%d recall score: %.4f" % (topk, recall))


def main():
    args = get_args()

    ## get tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.query_encoder_path)

    ## get retriever model
    query_encoder = AutoModel.from_pretrained(args.query_encoder_path)
    context_encoder = AutoModel.from_pretrained(args.context_encoder_path)
    query_encoder.to("cuda"), query_encoder.eval()
    context_encoder.to("cuda"), context_encoder.eval()

    ## get evaluation data
    if args.eval_dataset == "doc2dial":
        input_datapath = os.path.join(args.data_folder, args.doc2dial_datapath)
        input_docpath = os.path.join(args.data_folder, args.doc2dial_docpath)
    elif args.eval_dataset == "quac":
        input_datapath = os.path.join(args.data_folder, args.quac_datapath)
        input_docpath = os.path.join(args.data_folder, args.quac_docpath)
    elif args.eval_dataset == "qrecc":
        input_datapath = os.path.join(args.data_folder, args.qrecc_datapath)
        input_docpath = os.path.join(args.data_folder, args.qrecc_docpath)
    elif args.eval_dataset == "topiocqa" or args.eval_dataset == "inscit":
        raise Exception("We have prepare the function to get queries, but a wikipedia corpus needs to be downloaded")
    else:
        raise Exception("Please input a correct eval_dataset name!")

    eval_data, documents = get_data_for_evaluation(input_datapath, input_docpath, args.eval_dataset)
    
    ## run retrieval
    ranked_indices_list, gold_index_list = run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer)
    print("number of the total test samples: %d" % len(ranked_indices_list))

    ## calculate recall scores
    print("evaluating on %s" % args.eval_dataset)
    topk_list = [1, 5, 20]
    for topk in topk_list:
        calculate_recall(ranked_indices_list, gold_index_list, topk=topk)


if __name__ == "__main__":
    main()