Upload 4 files
- evaluation/README.md +8 -0
- evaluation/arguments.py +31 -0
- evaluation/dataset.py +74 -0
- evaluation/evaluate.py +102 -0
evaluation/README.md
ADDED
@@ -0,0 +1,8 @@
+
+### Commands for running evaluation
+
+```console
+python evaluate.py --eval-dataset doc2dial
+python evaluate.py --eval-dataset quac
+python evaluate.py --eval-dataset qrecc
+```
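The commands above assume the ChatRAG Bench data sits at the default relative paths (e.g. `doc2dial/test.json`); if the data lives elsewhere, the `--data-folder` flag defined in `evaluation/arguments.py` can be added, e.g. `python evaluate.py --data-folder /path/to/chatrag_bench --eval-dataset doc2dial` (the folder path here is only a placeholder).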
evaluation/arguments.py
ADDED
@@ -0,0 +1,31 @@
+
+import argparse
+import os
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Dragon-multiturn")
+
+    parser.add_argument('--query-encoder-path', type=str, default='nvidia/dragon-multiturn-query-encoder')
+    parser.add_argument('--context-encoder-path', type=str, default='nvidia/dragon-multiturn-context-encoder')
+
+    parser.add_argument('--data-folder', type=str, default='', help='path to the datafolder of ChatRAG Bench')
+    parser.add_argument('--eval-dataset', type=str, default='', help='evaluation dataset (e.g., doc2dial)')
+
+    parser.add_argument('--doc2dial-datapath', type=str, default='doc2dial/test.json')
+    parser.add_argument('--doc2dial-docpath', type=str, default='doc2dial/documents.json')
+
+    parser.add_argument('--quac-datapath', type=str, default='quac/test.json')
+    parser.add_argument('--quac-docpath', type=str, default='quac/documents.json')
+
+    parser.add_argument('--qrecc-datapath', type=str, default='qrecc/test.json')
+    parser.add_argument('--qrecc-docpath', type=str, default='qrecc/documents.json')
+
+    parser.add_argument('--topiocqa-datapath', type=str, default='topiocqa/dev.json')
+    parser.add_argument('--topiocqa-docpath', type=str, default='')
+
+    parser.add_argument('--inscit-datapath', type=str, default='inscit/dev.json')
+    parser.add_argument('--inscit-docpath', type=str, default='')
+
+    args = parser.parse_args()
+
+    return args
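A minimal sketch (not part of this commit) of how these arguments resolve when `evaluate.py` is launched; the `--data-folder` value and the `sys.argv` override below are purely illustrative.

```python
import os
import sys

from arguments import get_args

# Simulate: python evaluate.py --data-folder /path/to/chatrag_bench --eval-dataset doc2dial
sys.argv = ["evaluate.py", "--data-folder", "/path/to/chatrag_bench", "--eval-dataset", "doc2dial"]
args = get_args()

print(args.query_encoder_path)  # nvidia/dragon-multiturn-query-encoder (default)
print(args.eval_dataset)        # doc2dial
# evaluate.py joins the data folder with the per-dataset relative paths:
print(os.path.join(args.data_folder, args.doc2dial_datapath))  # /path/to/chatrag_bench/doc2dial/test.json
print(os.path.join(args.data_folder, args.doc2dial_docpath))   # /path/to/chatrag_bench/doc2dial/documents.json
```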
evaluation/dataset.py
ADDED
@@ -0,0 +1,74 @@
+
+import json
+
+
+def get_query(messages, num_turns=5):
+    ## convert query into a format as follows:
+    ## user: {user}\nagent: {agent}\nuser: {user}
+    query = ""
+    for item in messages[-num_turns:]:
+        item['role'] = item['role'].replace("assistant", "agent")
+        query += "{}: {}\n".format(item['role'], item['content'])
+    query = query.strip()
+
+    return query
+
+
+def get_query_with_topic(messages, topic, num_turns=3):
+    ## convert query into a format as follows:
+    ## user: this is a question about {topic}. {user}\nagent: {agent}\nuser: this is a question about {topic}. {user}
+    query = ""
+    for item in messages[-num_turns:]:
+        item['role'] = item['role'].replace("assistant", "agent")
+        if item['role'] == 'user':
+            query += "{}: this is a question about {}. {}\n".format(item['role'], topic, item['content'])
+        else:
+            query += "{}: {}\n".format(item['role'], item['content'])
+    query = query.strip()
+
+    return query
+
+
+def get_data_for_evaluation(input_datapath, document_datapath, dataset_name):
+
+    print('reading evaluation data from %s' % input_datapath)
+    with open(input_datapath, "r") as f:
+        input_list = json.load(f)
+
+    print('reading documents from %s' % document_datapath)
+    with open(document_datapath, "r") as f:
+        documents = json.load(f)
+
+    eval_data = {}
+    for item in input_list:
+        """
+        We incorporate topic information for the topiocqa and inscit datasets:
+        query = get_query_with_topic(item['messages'], item['topic'])
+        """
+        query = get_query(item['messages'])
+
+        doc_id = item['document']
+        gold_idx = item['ground_truth_ctx']['index']
+
+        if dataset_name == 'qrecc':
+            """
+            The 'gold context' for the qrecc dataset is obtained based on the word
+            overlaps between the gold answer and each context in the document, which
+            might not be the real gold context.
+
+            To improve the evaluation quality of this dataset,
+            we further add the answer of the query into the 'gold context'
+            to ensure the 'gold context' is the most relevant chunk to the query.
+
+            Note that this is only for retrieval evaluation purposes; we do not
+            add the answer to the context for the ChatRAG evaluation.
+            """
+            answer = item['answers'][0]
+            documents[doc_id][gold_idx] += " || " + answer
+
+        if doc_id not in eval_data:
+            eval_data[doc_id] = [{"query": query, "gold_idx": gold_idx}]
+        else:
+            eval_data[doc_id].append({"query": query, "gold_idx": gold_idx})
+
+    return eval_data, documents
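A minimal sketch (not part of this commit) of the multi-turn query string that `get_query` builds; the conversation below is a made-up example.

```python
from dataset import get_query

messages = [
    {"role": "user", "content": "What does the warranty cover?"},
    {"role": "assistant", "content": "It covers manufacturing defects for two years."},
    {"role": "user", "content": "Does it also cover accidental damage?"},
]

# The last num_turns messages are flattened into "role: content" lines,
# with "assistant" renamed to "agent".
print(get_query(messages))
# user: What does the warranty cover?
# agent: It covers manufacturing defects for two years.
# user: Does it also cover accidental damage?
```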
evaluation/evaluate.py
ADDED
@@ -0,0 +1,102 @@
+
+from transformers import AutoModel, AutoTokenizer
+from dataset import get_data_for_evaluation
+from arguments import get_args
+from tqdm import tqdm
+import torch
+import os
+
+
+def run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer, max_seq_len=512):
+
+    ranked_indices_list = []
+    gold_index_list = []
+    for doc_id in tqdm(eval_data):
+        context_list = documents[doc_id]
+
+        with torch.no_grad():
+            # get chunk embeddings
+            context_embs = []
+            for chunk in context_list:
+                chunk_ids = tokenizer(chunk, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")
+
+                c_emb = context_encoder(input_ids=chunk_ids.input_ids, attention_mask=chunk_ids.attention_mask)
+                c_emb = c_emb.last_hidden_state[:, 0, :]
+                context_embs.append(c_emb)
+            context_embs = torch.cat(context_embs, dim=0)  # (num_chunk, hidden_dim)
+
+            sample_list = eval_data[doc_id]
+            query_embs = []
+            for item in sample_list:
+                gold_idx = item['gold_idx']
+                gold_index_list.append(gold_idx)
+
+                query = item['query']
+                query_ids = tokenizer(query, max_length=max_seq_len, truncation=True, return_tensors="pt").to("cuda")
+                q_emb = query_encoder(input_ids=query_ids.input_ids, attention_mask=query_ids.attention_mask)
+                q_emb = q_emb.last_hidden_state[:, 0, :]
+                query_embs.append(q_emb)
+            query_embs = torch.cat(query_embs, dim=0)  # (num_query, hidden_dim)
+
+            similarities = query_embs.matmul(context_embs.transpose(0, 1))  # (num_query, num_chunk)
+            ranked_results = torch.argsort(similarities, dim=-1, descending=True)  # (num_query, num_chunk)
+            ranked_indices_list.extend(ranked_results.tolist())
+
+    return ranked_indices_list, gold_index_list
+
+
+def calculate_recall(ranked_indices_list, gold_index_list, topk):
+    hit = 0
+    for ranked_indices, gold_index in zip(ranked_indices_list, gold_index_list):
+        for idx in ranked_indices[:topk]:
+            if idx == gold_index:
+                hit += 1
+                break
+    recall = hit / len(ranked_indices_list)
+
+    print("top-%d recall score: %.4f" % (topk, recall))
+
+
+def main():
+    args = get_args()
+
+    ## get tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(args.query_encoder_path)
+
+    ## get retriever model
+    query_encoder = AutoModel.from_pretrained(args.query_encoder_path)
+    context_encoder = AutoModel.from_pretrained(args.context_encoder_path)
+    query_encoder.to("cuda"), query_encoder.eval()
+    context_encoder.to("cuda"), context_encoder.eval()
+
+    ## get evaluation data
+    if args.eval_dataset == "doc2dial":
+        input_datapath = os.path.join(args.data_folder, args.doc2dial_datapath)
+        input_docpath = os.path.join(args.data_folder, args.doc2dial_docpath)
+    elif args.eval_dataset == "quac":
+        input_datapath = os.path.join(args.data_folder, args.quac_datapath)
+        input_docpath = os.path.join(args.data_folder, args.quac_docpath)
+    elif args.eval_dataset == "qrecc":
+        input_datapath = os.path.join(args.data_folder, args.qrecc_datapath)
+        input_docpath = os.path.join(args.data_folder, args.qrecc_docpath)
+    elif args.eval_dataset == "topiocqa" or args.eval_dataset == "inscit":
+        raise Exception("We have prepared the function to get queries, but a Wikipedia corpus needs to be downloaded")
+    else:
+        raise Exception("Please input a correct eval_dataset name!")
+
+    eval_data, documents = get_data_for_evaluation(input_datapath, input_docpath, args.eval_dataset)
+
+    ## run retrieval
+    ranked_indices_list, gold_index_list = run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer)
+    print("number of the total test samples: %d" % len(ranked_indices_list))
+
+    ## calculate recall scores
+    print("evaluating on %s" % args.eval_dataset)
+    topk_list = [1, 5, 20]
+    for topk in topk_list:
+        calculate_recall(ranked_indices_list, gold_index_list, topk=topk)
+
+
+if __name__ == "__main__":
+    main()
+
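A minimal sketch (not part of this commit) of how `calculate_recall` scores the per-query rankings produced by `run_retrieval`; the rankings and gold indices below are made up.

```python
from evaluate import calculate_recall

# Each row ranks the chunk indices of one document for one query,
# best match first (as returned by torch.argsort on the similarities).
ranked_indices_list = [
    [2, 0, 1, 3],  # gold chunk 2 ranked first  -> counted at top-1
    [1, 3, 0, 2],  # gold chunk 0 ranked third  -> counted at top-5 only
    [3, 2, 1, 0],  # gold chunk 0 ranked fourth -> counted at top-5 only
]
gold_index_list = [2, 0, 0]

calculate_recall(ranked_indices_list, gold_index_list, topk=1)  # top-1 recall score: 0.3333
calculate_recall(ranked_indices_list, gold_index_list, topk=5)  # top-5 recall score: 1.0000
```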