zihanliu committed on
Commit 14ca076
1 Parent(s): bf650c8

Upload 4 files

evaluation/README.md ADDED
@@ -0,0 +1,8 @@

### Commands for running evaluation

```console
python evaluate.py --eval-dataset doc2dial
python evaluate.py --eval-dataset quac
python evaluate.py --eval-dataset qrecc
```
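The commands above use the default relative paths under `--data-folder` (see `arguments.py` below). If the ChatRAG Bench data is stored elsewhere, the folder can be passed explicitly; a minimal sketch, where `/path/to/chatrag_bench` is a placeholder:

```console
python evaluate.py --data-folder /path/to/chatrag_bench --eval-dataset doc2dial
```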
evaluation/arguments.py ADDED
@@ -0,0 +1,31 @@

import argparse
import os

def get_args():
    parser = argparse.ArgumentParser(description="Dragon-multiturn")

    parser.add_argument('--query-encoder-path', type=str, default='nvidia/dragon-multiturn-query-encoder')
    parser.add_argument('--context-encoder-path', type=str, default='nvidia/dragon-multiturn-context-encoder')

    parser.add_argument('--data-folder', type=str, default='', help='path to the data folder of ChatRAG Bench')
    parser.add_argument('--eval-dataset', type=str, default='', help='evaluation dataset (e.g., doc2dial)')

    parser.add_argument('--doc2dial-datapath', type=str, default='doc2dial/test.json')
    parser.add_argument('--doc2dial-docpath', type=str, default='doc2dial/documents.json')

    parser.add_argument('--quac-datapath', type=str, default='quac/test.json')
    parser.add_argument('--quac-docpath', type=str, default='quac/documents.json')

    parser.add_argument('--qrecc-datapath', type=str, default='qrecc/test.json')
    parser.add_argument('--qrecc-docpath', type=str, default='qrecc/documents.json')

    parser.add_argument('--topiocqa-datapath', type=str, default='topiocqa/dev.json')
    parser.add_argument('--topiocqa-docpath', type=str, default='')

    parser.add_argument('--inscit-datapath', type=str, default='inscit/dev.json')
    parser.add_argument('--inscit-docpath', type=str, default='')

    args = parser.parse_args()

    return args
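For reference, a minimal sketch of how these arguments are consumed: `evaluate.py` joins `--data-folder` with the per-dataset relative paths above. The `/data/chatrag` folder and the wrapper script name are hypothetical.

```python
# Hypothetical wrapper, run from the evaluation/ folder, e.g.:
#   python resolve_paths.py --data-folder /data/chatrag --eval-dataset quac
import os
from arguments import get_args

args = get_args()
print(os.path.join(args.data_folder, args.quac_datapath))  # /data/chatrag/quac/test.json
print(os.path.join(args.data_folder, args.quac_docpath))   # /data/chatrag/quac/documents.json
```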
evaluation/dataset.py ADDED
@@ -0,0 +1,74 @@

import json


def get_query(messages, num_turns=5):
    ## convert the query into the following format:
    ## user: {user}\nagent: {agent}\nuser: {user}
    query = ""
    for item in messages[-num_turns:]:
        item['role'] = item['role'].replace("assistant", "agent")
        query += "{}: {}\n".format(item['role'], item['content'])
    query = query.strip()

    return query


def get_query_with_topic(messages, topic, num_turns=3):
    ## convert the query into the following format:
    ## user: this is a question about {topic}. {user}\nagent: {agent}\nuser: this is a question about {topic}. {user}
    query = ""
    for item in messages[-num_turns:]:
        item['role'] = item['role'].replace("assistant", "agent")
        if item['role'] == 'user':
            query += "{}: this is a question about {}. {}\n".format(item['role'], topic, item['content'])
        else:
            query += "{}: {}\n".format(item['role'], item['content'])
    query = query.strip()

    return query


def get_data_for_evaluation(input_datapath, document_datapath, dataset_name):

    print('reading evaluation data from %s' % input_datapath)
    with open(input_datapath, "r") as f:
        input_list = json.load(f)

    print('reading documents from %s' % document_datapath)
    with open(document_datapath, "r") as f:
        documents = json.load(f)

    eval_data = {}
    for item in input_list:
        """
        For the topiocqa and inscit datasets, topic information is incorporated:
        query = get_query_with_topic(item['messages'], item['topic'])
        """
        query = get_query(item['messages'])

        doc_id = item['document']
        gold_idx = item['ground_truth_ctx']['index']

        if dataset_name == 'qrecc':
            """
            The 'gold context' for the qrecc dataset is obtained based on the word
            overlap between the gold answer and each context in the document, which
            might not be the real gold context.

            To improve the evaluation quality of this dataset,
            we further add the answer of the query to the 'gold context'
            to ensure the 'gold context' is the most relevant chunk to the query.

            Note that this is only for the retrieval evaluation; we do not
            add the answer to the context for the ChatRAG evaluation.
            """
            answer = item['answers'][0]
            documents[doc_id][gold_idx] += " || " + answer

        if doc_id not in eval_data:
            eval_data[doc_id] = [{"query": query, "gold_idx": gold_idx}]
        else:
            eval_data[doc_id].append({"query": query, "gold_idx": gold_idx})

    return eval_data, documents
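As a quick illustration of the query formatting above, a minimal sketch with a made-up conversation (the dialogue contents and topic are hypothetical):

```python
# Toy example, run from the evaluation/ folder so the local dataset.py is imported.
from dataset import get_query, get_query_with_topic

messages = [
    {"role": "user", "content": "How do I renew my license?"},
    {"role": "assistant", "content": "You can renew it online or at a DMV office."},
    {"role": "user", "content": "What documents do I need?"},
]

print(get_query(messages))
# user: How do I renew my license?
# agent: You can renew it online or at a DMV office.
# user: What documents do I need?

print(get_query_with_topic(messages, topic="driver licensing"))
# user: this is a question about driver licensing. How do I renew my license?
# agent: You can renew it online or at a DMV office.
# user: this is a question about driver licensing. What documents do I need?
```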
evaluation/evaluate.py ADDED
@@ -0,0 +1,102 @@

from transformers import AutoModel, AutoTokenizer
from dataset import get_data_for_evaluation
from arguments import get_args
from tqdm import tqdm
import torch
import os


def run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer, max_seq_len=512):

    ranked_indices_list = []
    gold_index_list = []
    for doc_id in tqdm(eval_data):
        context_list = documents[doc_id]

        with torch.no_grad():
            # get chunk embeddings
            context_embs = []
            for chunk in context_list:
                chunk_ids = tokenizer(chunk, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")
                c_emb = context_encoder(input_ids=chunk_ids.input_ids, attention_mask=chunk_ids.attention_mask)
                c_emb = c_emb.last_hidden_state[:, 0, :]
                context_embs.append(c_emb)
            context_embs = torch.cat(context_embs, dim=0)  # (num_chunk, hidden_dim)

            # get query embeddings
            sample_list = eval_data[doc_id]
            query_embs = []
            for item in sample_list:
                gold_idx = item['gold_idx']
                gold_index_list.append(gold_idx)

                query = item['query']
                query_ids = tokenizer(query, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")
                q_emb = query_encoder(input_ids=query_ids.input_ids, attention_mask=query_ids.attention_mask)
                q_emb = q_emb.last_hidden_state[:, 0, :]
                query_embs.append(q_emb)
            query_embs = torch.cat(query_embs, dim=0)  # (num_query, hidden_dim)

            # rank context chunks for each query by dot-product similarity
            similarities = query_embs.matmul(context_embs.transpose(0, 1))  # (num_query, num_chunk)
            ranked_results = torch.argsort(similarities, dim=-1, descending=True)  # (num_query, num_chunk)
            ranked_indices_list.extend(ranked_results.tolist())

    return ranked_indices_list, gold_index_list


def calculate_recall(ranked_indices_list, gold_index_list, topk):
    hit = 0
    for ranked_indices, gold_index in zip(ranked_indices_list, gold_index_list):
        for idx in ranked_indices[:topk]:
            if idx == gold_index:
                hit += 1
                break
    recall = hit / len(ranked_indices_list)

    print("top-%d recall score: %.4f" % (topk, recall))


def main():
    args = get_args()

    ## get tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.query_encoder_path)

    ## get retriever models
    query_encoder = AutoModel.from_pretrained(args.query_encoder_path)
    context_encoder = AutoModel.from_pretrained(args.context_encoder_path)
    query_encoder.to("cuda"), query_encoder.eval()
    context_encoder.to("cuda"), context_encoder.eval()

    ## get evaluation data
    if args.eval_dataset == "doc2dial":
        input_datapath = os.path.join(args.data_folder, args.doc2dial_datapath)
        input_docpath = os.path.join(args.data_folder, args.doc2dial_docpath)
    elif args.eval_dataset == "quac":
        input_datapath = os.path.join(args.data_folder, args.quac_datapath)
        input_docpath = os.path.join(args.data_folder, args.quac_docpath)
    elif args.eval_dataset == "qrecc":
        input_datapath = os.path.join(args.data_folder, args.qrecc_datapath)
        input_docpath = os.path.join(args.data_folder, args.qrecc_docpath)
    elif args.eval_dataset == "topiocqa" or args.eval_dataset == "inscit":
        raise Exception("We have prepared the query-construction function, but a Wikipedia corpus needs to be downloaded first")
    else:
        raise Exception("Please provide a valid eval_dataset name!")

    eval_data, documents = get_data_for_evaluation(input_datapath, input_docpath, args.eval_dataset)

    ## run retrieval
    ranked_indices_list, gold_index_list = run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer)
    print("number of total test samples: %d" % len(ranked_indices_list))

    ## calculate recall scores
    print("evaluating on %s" % args.eval_dataset)
    topk_list = [1, 5, 20]
    for topk in topk_list:
        calculate_recall(ranked_indices_list, gold_index_list, topk=topk)


if __name__ == "__main__":
    main()
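To make the recall metric concrete, a minimal sketch with made-up rankings and gold indices, scored exactly as `calculate_recall` does above:

```python
# Toy example, run from the evaluation/ folder so the local evaluate.py is imported.
from evaluate import calculate_recall

ranked_indices_list = [
    [2, 4, 1, 0, 3],  # gold index 2 ranked 1st -> counts toward top-1 and top-5
    [3, 0, 1, 2, 4],  # gold index 0 ranked 2nd -> counts toward top-5 only
    [1, 4, 3, 0, 2],  # gold index 5 never retrieved -> counts toward neither
]
gold_index_list = [2, 0, 5]

calculate_recall(ranked_indices_list, gold_index_list, topk=1)  # top-1 recall score: 0.3333
calculate_recall(ranked_indices_list, gold_index_list, topk=5)  # top-5 recall score: 0.6667
```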