Upload training & evaluation script
Browse files- evaluate_trivia_qa.py +285 -0
- training_trivia_qa_bce.py +172 -0
evaluate_trivia_qa.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from datasets import load_dataset
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
from sentence_transformers.cross_encoder import CrossEncoder
|
7 |
+
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
|
8 |
+
from sentence_transformers.util import mine_hard_negatives
|
9 |
+
|
10 |
+
# Set the log level to INFO to get more information
|
11 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
|
12 |
+
|
13 |
+
|
14 |
+
def main():
    """Benchmark a set of off-the-shelf rerankers on a TriviaQA reranking task."""
    eval_batch_size = 16

    # Load the TriviaQA question-answer pairs; the first 1000 rows form the eval split.
    logging.info("Read the trivia-qa reranking dataset")
    full_dataset = load_dataset("sentence-transformers/trivia-qa", split="train")
    eval_dataset = full_dataset.select(range(1000))  # first 1000 samples for evaluation
    logging.info(eval_dataset)

    # Mine 30 candidate documents per query using a cheap static embedding model.
    # The gold answer stays among the candidates (include_positives=True), so the
    # evaluators can report the embedding model's own ranking as the baseline.
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # rerank against the full answer corpus
        num_negatives=30,  # how many documents each reranker gets to reorder
        batch_size=4096,
        include_positives=True,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)

    # Reshape the mined n-tuples into the {query, positive, documents} records that
    # CrossEncoderRerankingEvaluator expects. Columns 0 and 1 hold query/answer; the
    # remaining columns are the mined candidate documents.
    candidate_columns = hard_eval_dataset.column_names[2:]
    samples = []
    for row in hard_eval_dataset:
        samples.append(
            {
                "query": row["query"],
                "positive": [row["answer"]],
                "documents": [row[column] for column in candidate_columns],
            }
        )

    # Two evaluators over the same samples:
    #   * "realistic": only the mined top-30 is reranked, so a missed positive stays missed;
    #   * "evaluation": the positive answer is always injected into the rerank pool.
    realistic_reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=samples,
        batch_size=eval_batch_size,
        name="trivia-qa-dev-realistic",
        always_rerank_positives=False,
        show_progress_bar=True,
    )
    evaluation_reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=samples,
        batch_size=eval_batch_size,
        name="trivia-qa-dev-evaluation",
        always_rerank_positives=True,
        show_progress_bar=True,
    )

    model_names = [
        "tomaarsen/reranker-ModernBERT-base-trivia-qa-bce",
        "cross-encoder/ms-marco-MiniLM-L6-v2",
        "jinaai/jina-reranker-v1-tiny-en",
        "jinaai/jina-reranker-v1-turbo-en",
        "jinaai/jina-reranker-v2-base-multilingual",
        "BAAI/bge-reranker-base",
        "BAAI/bge-reranker-large",
        "BAAI/bge-reranker-v2-m3",
        "mixedbread-ai/mxbai-rerank-xsmall-v1",
        "mixedbread-ai/mxbai-rerank-base-v1",
        "mixedbread-ai/mxbai-rerank-large-v1",
        # "mixedbread-ai/mxbai-rerank-base-v2",
        # "mixedbread-ai/mxbai-rerank-large-v2",
        "Alibaba-NLP/gte-reranker-modernbert-base",
    ]
    for model_name in model_names:
        # Load the reranker in bfloat16 to keep memory use down.
        logging.info(f"Loading {model_name} model")
        # NOTE(review): jina models need max_length=1024, bge-reranker-base and -large need
        # max_length=512 -- not set here, relying on each model's own configuration.
        cross_encoder = CrossEncoder(model_name, model_kwargs={"torch_dtype": torch.bfloat16}, trust_remote_code=True)

        # Run both evaluators and print the raw metric dictionaries.
        logging.info(f"Evaluating {model_name}")
        print(model_name)
        print(realistic_reranking_evaluator(cross_encoder))
        print(evaluation_reranking_evaluator(cross_encoder))


if __name__ == "__main__":
    main()
|
94 |
+
|
95 |
+
"""
|
96 |
+
2025-03-28 10:17:20 - Evaluating tomaarsen/reranker-ModernBERT-base-trivia-qa-bce
|
97 |
+
2025-03-28 10:17:20 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
|
98 |
+
2025-03-28 10:21:06 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
99 |
+
2025-03-28 10:21:06 - Base -> Reranked
|
100 |
+
2025-03-28 10:21:06 - MAP: 46.91 -> 65.07
|
101 |
+
2025-03-28 10:21:06 - MRR@10: 46.19 -> 65.33
|
102 |
+
2025-03-28 10:21:06 - NDCG@10: 52.31 -> 69.21
|
103 |
+
2025-03-28 10:21:06 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
|
104 |
+
2025-03-28 10:25:39 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
105 |
+
2025-03-28 10:25:39 - Base -> Reranked
|
106 |
+
2025-03-28 10:25:39 - MAP: 46.91 -> 76.18
|
107 |
+
2025-03-28 10:25:39 - MRR@10: 46.19 -> 76.46
|
108 |
+
2025-03-28 10:25:39 - NDCG@10: 52.31 -> 81.16
|
109 |
+
|
110 |
+
2025-03-28 10:25:39 - Loading cross-encoder/ms-marco-MiniLM-L6-v2 model
|
111 |
+
2025-03-28 10:25:40 - Use pytorch device: cuda
|
112 |
+
2025-03-28 10:25:40 - Evaluating cross-encoder/ms-marco-MiniLM-L6-v2
|
113 |
+
2025-03-28 10:25:40 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
|
114 |
+
2025-03-28 10:26:08 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
115 |
+
2025-03-28 10:26:08 - Base -> Reranked
|
116 |
+
2025-03-28 10:26:08 - MAP: 46.91 -> 59.97
|
117 |
+
2025-03-28 10:26:08 - MRR@10: 46.19 -> 59.72
|
118 |
+
2025-03-28 10:26:08 - NDCG@10: 52.31 -> 64.26
|
119 |
+
2025-03-28 10:26:08 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
|
120 |
+
2025-03-28 10:26:41 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
121 |
+
2025-03-28 10:26:41 - Base -> Reranked
|
122 |
+
2025-03-28 10:26:41 - MAP: 46.91 -> 65.99
|
123 |
+
2025-03-28 10:26:41 - MRR@10: 46.19 -> 65.41
|
124 |
+
2025-03-28 10:26:41 - NDCG@10: 52.31 -> 70.82
|
125 |
+
|
126 |
+
2025-03-28 10:26:41 - Loading jinaai/jina-reranker-v1-tiny-en model
|
127 |
+
2025-03-28 10:26:43 - Use pytorch device: cuda
|
128 |
+
2025-03-28 10:26:44 - Evaluating jinaai/jina-reranker-v1-tiny-en
|
129 |
+
2025-03-28 10:26:44 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
|
130 |
+
2025-03-28 10:28:49 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
131 |
+
2025-03-28 10:28:49 - Base -> Reranked
|
132 |
+
2025-03-28 10:28:49 - MAP: 46.91 -> 59.46
|
133 |
+
2025-03-28 10:28:49 - MRR@10: 46.19 -> 59.57
|
134 |
+
2025-03-28 10:28:49 - NDCG@10: 52.31 -> 64.11
|
135 |
+
2025-03-28 10:28:49 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
|
136 |
+
2025-03-28 10:30:53 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
137 |
+
2025-03-28 10:30:53 - Base -> Reranked
|
138 |
+
2025-03-28 10:30:53 - MAP: 46.91 -> 65.14
|
139 |
+
2025-03-28 10:30:53 - MRR@10: 46.19 -> 65.69
|
140 |
+
2025-03-28 10:30:53 - NDCG@10: 52.31 -> 70.47
|
141 |
+
|
142 |
+
2025-03-28 10:33:24 - Evaluating jinaai/jina-reranker-v1-turbo-en
|
143 |
+
2025-03-28 10:33:24 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
|
144 |
+
2025-03-28 10:36:16 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
145 |
+
2025-03-28 10:36:16 - Base -> Reranked
|
146 |
+
2025-03-28 10:36:16 - MAP: 46.91 -> 59.21
|
147 |
+
2025-03-28 10:36:16 - MRR@10: 46.19 -> 59.03
|
148 |
+
2025-03-28 10:36:16 - NDCG@10: 52.31 -> 63.75
|
149 |
+
2025-03-28 10:36:16 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
|
150 |
+
2025-03-28 10:39:09 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
151 |
+
2025-03-28 10:39:09 - Base -> Reranked
|
152 |
+
2025-03-28 10:39:09 - MAP: 46.91 -> 65.39
|
153 |
+
2025-03-28 10:39:09 - MRR@10: 46.19 -> 65.04
|
154 |
+
2025-03-28 10:39:09 - NDCG@10: 52.31 -> 70.70
|
155 |
+
|
156 |
+
2025-03-28 10:54:00 - Evaluating jinaai/jina-reranker-v2-base-multilingual
|
157 |
+
2025-03-28 10:54:00 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
|
158 |
+
2025-03-28 10:56:49 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
159 |
+
2025-03-28 10:56:49 - Base -> Reranked
|
160 |
+
2025-03-28 10:56:49 - MAP: 46.91 -> 61.59
|
161 |
+
2025-03-28 10:56:49 - MRR@10: 46.19 -> 61.63
|
162 |
+
2025-03-28 10:56:49 - NDCG@10: 52.31 -> 66.16
|
163 |
+
2025-03-28 10:56:49 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
|
164 |
+
2025-03-28 11:00:14 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
165 |
+
2025-03-28 11:00:14 - Base -> Reranked
|
166 |
+
2025-03-28 11:00:14 - MAP: 46.91 -> 69.46
|
167 |
+
2025-03-28 11:00:14 - MRR@10: 46.19 -> 69.47
|
168 |
+
2025-03-28 11:00:14 - NDCG@10: 52.31 -> 74.74
|
169 |
+
|
170 |
+
2025-03-28 11:14:39 - Evaluating BAAI/bge-reranker-base
|
171 |
+
BAAI/bge-reranker-base
|
172 |
+
2025-03-28 11:14:39 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
|
173 |
+
2025-03-28 11:15:42 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
174 |
+
2025-03-28 11:15:42 - Base -> Reranked
|
175 |
+
2025-03-28 11:15:42 - MAP: 46.91 -> 52.59
|
176 |
+
2025-03-28 11:15:42 - MRR@10: 46.19 -> 54.37
|
177 |
+
2025-03-28 11:15:42 - NDCG@10: 52.31 -> 59.41
|
178 |
+
{'trivia-qa-dev-realistic_map': 0.5259323270998224, 'trivia-qa-dev-realistic_mrr@10': 0.5436916666666666, 'trivia-qa-dev-realistic_ndcg@10': 0.5940858318692244, 'trivia-qa-dev-realistic_base_map': 0.46913299504084743, 'trivia-qa-dev-realistic_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-realistic_base_ndcg@10': 0.5231399731095658}
|
179 |
+
2025-03-28 11:15:42 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
|
180 |
+
2025-03-28 11:16:58 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
181 |
+
2025-03-28 11:16:58 - Base -> Reranked
|
182 |
+
2025-03-28 11:16:58 - MAP: 46.91 -> 59.04
|
183 |
+
2025-03-28 11:16:58 - MRR@10: 46.19 -> 62.72
|
184 |
+
2025-03-28 11:16:58 - NDCG@10: 52.31 -> 66.61
|
185 |
+
{'trivia-qa-dev-evaluation_map': 0.590369907985238, 'trivia-qa-dev-evaluation_mrr@10': 0.6271603174603174, 'trivia-qa-dev-evaluation_ndcg@10': 0.6661314116591164, 'trivia-qa-dev-evaluation_base_map': 0.46913299504084743, 'trivia-qa-dev-evaluation_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-evaluation_base_ndcg@10': 0.5231399731095658}
|
186 |
+
|
187 |
+
2025-03-28 11:17:01 - Evaluating BAAI/bge-reranker-large
|
188 |
+
BAAI/bge-reranker-large
|
189 |
+
2025-03-28 11:17:01 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
|
190 |
+
2025-03-28 11:19:48 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
191 |
+
2025-03-28 11:19:48 - Base -> Reranked
|
192 |
+
2025-03-28 11:19:48 - MAP: 46.91 -> 55.68
|
193 |
+
2025-03-28 11:19:48 - MRR@10: 46.19 -> 56.87
|
194 |
+
2025-03-28 11:19:48 - NDCG@10: 52.31 -> 61.76
|
195 |
+
{'trivia-qa-dev-realistic_map': 0.5567810015278374, 'trivia-qa-dev-realistic_mrr@10': 0.5687162698412699, 'trivia-qa-dev-realistic_ndcg@10': 0.6176060985342933, 'trivia-qa-dev-realistic_base_map': 0.46913299504084743, 'trivia-qa-dev-realistic_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-realistic_base_ndcg@10': 0.5231399731095658}
|
196 |
+
2025-03-28 11:19:48 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
|
197 |
+
2025-03-28 11:23:10 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
198 |
+
2025-03-28 11:23:10 - Base -> Reranked
|
199 |
+
2025-03-28 11:23:10 - MAP: 46.91 -> 62.93
|
200 |
+
2025-03-28 11:23:10 - MRR@10: 46.19 -> 65.24
|
201 |
+
2025-03-28 11:23:10 - NDCG@10: 52.31 -> 69.92
|
202 |
+
{'trivia-qa-dev-evaluation_map': 0.6292598276500135, 'trivia-qa-dev-evaluation_mrr@10': 0.6523801587301588, 'trivia-qa-dev-evaluation_ndcg@10': 0.6992497211715496, 'trivia-qa-dev-evaluation_base_map': 0.46913299504084743, 'trivia-qa-dev-evaluation_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-evaluation_base_ndcg@10': 0.5231399731095658}
|
203 |
+
|
204 |
+
bge-reranker-v2-m3
|
205 |
+
2025-03-28 11:33:42 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
206 |
+
2025-03-28 11:33:42 - Base -> Reranked
|
207 |
+
2025-03-28 11:33:42 - MAP: 46.91 -> 59.46
|
208 |
+
2025-03-28 11:33:42 - MRR@10: 46.19 -> 60.49
|
209 |
+
2025-03-28 11:33:42 - NDCG@10: 52.31 -> 64.85
|
210 |
+
{'trivia-qa-dev-realistic_map': 0.5945974714456489, 'trivia-qa-dev-realistic_mrr@10': 0.6049440476190477, 'trivia-qa-dev-realistic_ndcg@10': 0.6485089522801432, 'trivia-qa-dev-realistic_base_map': 0.46913299504084743, 'trivia-qa-dev-realistic_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-realistic_base_ndcg@10': 0.523139973109566}
|
211 |
+
2025-03-28 11:33:42 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-evaluation dataset:
|
212 |
+
2025-03-28 11:41:37 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
213 |
+
2025-03-28 11:41:37 - Base -> Reranked
|
214 |
+
2025-03-28 11:41:37 - MAP: 46.91 -> 67.12
|
215 |
+
2025-03-28 11:41:37 - MRR@10: 46.19 -> 68.82
|
216 |
+
2025-03-28 11:41:37 - NDCG@10: 52.31 -> 73.29
|
217 |
+
{'gooaq-dev-evaluation_map': 0.6712088650084718, 'gooaq-dev-evaluation_mrr@10': 0.6881884920634921, 'gooaq-dev-evaluation_ndcg@10': 0.7329280539251892, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
|
218 |
+
|
219 |
+
2025-03-28 11:41:38 - Evaluating mixedbread-ai/mxbai-rerank-xsmall-v1
|
220 |
+
mixedbread-ai/mxbai-rerank-xsmall-v1
|
221 |
+
2025-03-28 11:41:38 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-realistic dataset:
|
222 |
+
2025-03-28 11:42:59 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
223 |
+
2025-03-28 11:42:59 - Base -> Reranked
|
224 |
+
2025-03-28 11:42:59 - MAP: 46.91 -> 58.07
|
225 |
+
2025-03-28 11:42:59 - MRR@10: 46.19 -> 57.79
|
226 |
+
2025-03-28 11:42:59 - NDCG@10: 52.31 -> 62.53
|
227 |
+
{'gooaq-dev-realistic_map': 0.5806904579481978, 'gooaq-dev-realistic_mrr@10': 0.5778678571428572, 'gooaq-dev-realistic_ndcg@10': 0.6252806166705941, 'gooaq-dev-realistic_base_map': 0.46913299504084743, 'gooaq-dev-realistic_base_mrr@10': 0.4618571428571429, 'gooaq-dev-realistic_base_ndcg@10': 0.523139973109566}
|
228 |
+
2025-03-28 11:42:59 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-evaluation dataset:
|
229 |
+
2025-03-28 11:44:37 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
230 |
+
2025-03-28 11:44:37 - Base -> Reranked
|
231 |
+
2025-03-28 11:44:37 - MAP: 46.91 -> 63.42
|
232 |
+
2025-03-28 11:44:37 - MRR@10: 46.19 -> 63.23
|
233 |
+
2025-03-28 11:44:37 - NDCG@10: 52.31 -> 68.38
|
234 |
+
{'gooaq-dev-evaluation_map': 0.6341866885097522, 'gooaq-dev-evaluation_mrr@10': 0.6323373015873017, 'gooaq-dev-evaluation_ndcg@10': 0.683834674484044, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
|
235 |
+
|
236 |
+
2025-03-28 11:44:38 - Evaluating mixedbread-ai/mxbai-rerank-base-v1
|
237 |
+
mixedbread-ai/mxbai-rerank-base-v1
|
238 |
+
2025-03-28 11:44:38 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-realistic dataset:
|
239 |
+
2025-03-28 11:47:02 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
240 |
+
2025-03-28 11:47:02 - Base -> Reranked
|
241 |
+
2025-03-28 11:47:02 - MAP: 46.91 -> 55.81
|
242 |
+
2025-03-28 11:47:02 - MRR@10: 46.19 -> 55.68
|
243 |
+
2025-03-28 11:47:02 - NDCG@10: 52.31 -> 60.99
|
244 |
+
{'gooaq-dev-realistic_map': 0.5580880756399167, 'gooaq-dev-realistic_mrr@10': 0.5567904761904762, 'gooaq-dev-realistic_ndcg@10': 0.6099184869749001, 'gooaq-dev-realistic_base_map': 0.46913299504084743, 'gooaq-dev-realistic_base_mrr@10': 0.4618571428571429, 'gooaq-dev-realistic_base_ndcg@10': 0.523139973109566}
|
245 |
+
2025-03-28 11:47:02 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-evaluation dataset:
|
246 |
+
2025-03-28 11:49:56 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
247 |
+
2025-03-28 11:49:56 - Base -> Reranked
|
248 |
+
2025-03-28 11:49:56 - MAP: 46.91 -> 62.12
|
249 |
+
2025-03-28 11:49:56 - MRR@10: 46.19 -> 61.81
|
250 |
+
2025-03-28 11:49:56 - NDCG@10: 52.31 -> 67.69
|
251 |
+
{'gooaq-dev-evaluation_map': 0.62120330763951, 'gooaq-dev-evaluation_mrr@10': 0.6180714285714286, 'gooaq-dev-evaluation_ndcg@10': 0.6769237801354084, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
|
252 |
+
|
253 |
+
2025-03-28 11:49:57 - Evaluating mixedbread-ai/mxbai-rerank-large-v1
|
254 |
+
mixedbread-ai/mxbai-rerank-large-v1
|
255 |
+
2025-03-28 11:49:57 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-realistic dataset:
|
256 |
+
2025-03-28 11:56:18 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
257 |
+
2025-03-28 11:56:18 - Base -> Reranked
|
258 |
+
2025-03-28 11:56:18 - MAP: 46.91 -> 58.13
|
259 |
+
2025-03-28 11:56:18 - MRR@10: 46.19 -> 58.63
|
260 |
+
2025-03-28 11:56:18 - NDCG@10: 52.31 -> 63.38
|
261 |
+
{'gooaq-dev-realistic_map': 0.5813141616006278, 'gooaq-dev-realistic_mrr@10': 0.5862551587301587, 'gooaq-dev-realistic_ndcg@10': 0.6337779476332643, 'gooaq-dev-realistic_base_map': 0.46913299504084743, 'gooaq-dev-realistic_base_mrr@10': 0.4618571428571429, 'gooaq-dev-realistic_base_ndcg@10': 0.523139973109566}
|
262 |
+
2025-03-28 11:56:18 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-evaluation dataset:
|
263 |
+
2025-03-28 12:03:56 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
264 |
+
2025-03-28 12:03:56 - Base -> Reranked
|
265 |
+
2025-03-28 12:03:56 - MAP: 46.91 -> 64.79
|
266 |
+
2025-03-28 12:03:56 - MRR@10: 46.19 -> 65.64
|
267 |
+
2025-03-28 12:03:56 - NDCG@10: 52.31 -> 70.65
|
268 |
+
{'gooaq-dev-evaluation_map': 0.64787303000053, 'gooaq-dev-evaluation_mrr@10': 0.6563722222222222, 'gooaq-dev-evaluation_ndcg@10': 0.7065480874647052, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
|
269 |
+
|
270 |
+
2025-03-28 12:03:58 - Evaluating Alibaba-NLP/gte-reranker-modernbert-base
|
271 |
+
Alibaba-NLP/gte-reranker-modernbert-base
|
272 |
+
2025-03-28 12:03:58 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-realistic dataset:
|
273 |
+
2025-03-28 12:07:46 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
274 |
+
2025-03-28 12:07:46 - Base -> Reranked
|
275 |
+
2025-03-28 12:07:46 - MAP: 46.91 -> 61.20
|
276 |
+
2025-03-28 12:07:46 - MRR@10: 46.19 -> 61.96
|
277 |
+
2025-03-28 12:07:46 - NDCG@10: 52.31 -> 65.92
|
278 |
+
{'gooaq-dev-realistic_map': 0.6119823635603096, 'gooaq-dev-realistic_mrr@10': 0.6195595238095238, 'gooaq-dev-realistic_ndcg@10': 0.6591716946749766, 'gooaq-dev-realistic_base_map': 0.46913299504084743, 'gooaq-dev-realistic_base_mrr@10': 0.4618571428571429, 'gooaq-dev-realistic_base_ndcg@10': 0.523139973109566}
|
279 |
+
2025-03-28 12:12:23 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
|
280 |
+
2025-03-28 12:12:23 - Base -> Reranked
|
281 |
+
2025-03-28 12:12:23 - MAP: 46.91 -> 69.22
|
282 |
+
2025-03-28 12:12:23 - MRR@10: 46.19 -> 70.84
|
283 |
+
2025-03-28 12:12:23 - NDCG@10: 52.31 -> 74.88
|
284 |
+
{'gooaq-dev-evaluation_map': 0.6921869586576195, 'gooaq-dev-evaluation_mrr@10': 0.7084460317460317, 'gooaq-dev-evaluation_ndcg@10': 0.7487587539162411, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
|
285 |
+
"""
|
training_trivia_qa_bce.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import traceback
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from datasets import load_dataset
|
6 |
+
|
7 |
+
from sentence_transformers import SentenceTransformer
|
8 |
+
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderModelCardData
|
9 |
+
from sentence_transformers.cross_encoder.evaluation import (
|
10 |
+
CrossEncoderNanoBEIREvaluator,
|
11 |
+
CrossEncoderRerankingEvaluator,
|
12 |
+
)
|
13 |
+
from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
|
14 |
+
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
|
15 |
+
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
|
16 |
+
from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
|
17 |
+
from sentence_transformers.util import mine_hard_negatives
|
18 |
+
|
19 |
+
# Set the log level to INFO to get more information
|
20 |
+
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
|
21 |
+
|
22 |
+
|
23 |
+
def main():
    """Finetune ModernBERT-base as a TriviaQA reranker with binary cross-entropy loss.

    Pipeline: load base model -> split dataset -> mine hard negatives for training
    and evaluation -> train with BinaryCrossEntropyLoss -> evaluate, save, and push.
    """
    model_name = "answerdotai/ModernBERT-base"

    train_batch_size = 16
    num_epochs = 1
    num_hard_negatives = 5  # How many hard negatives should be mined for each question-answer pair

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name="ModernBERT-base trained on TriviaQA",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2a. Load the TriviaQA dataset: https://huggingface.co/datasets/sentence-transformers/trivia-qa
    logging.info("Read the trivia-qa training dataset")
    full_dataset = load_dataset("sentence-transformers/trivia-qa", split="train")
    eval_dataset = full_dataset.select(range(1000))  # Use the first 1000 samples for evaluation
    train_dataset = full_dataset.select(range(1000, len(full_dataset)))  # Use the rest of the samples for training
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 2b. Modify our training dataset to include hard negatives using a very efficient embedding model
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,  # How many negatives per question-answer pair
        margin=0,  # Similarity between query and negative samples should be x lower than query-positive similarity
        range_min=0,  # Skip the x most similar samples
        range_max=100,  # Consider only the x most similar samples
        sampling_strategy="top",  # Sample the top negatives from the range
        batch_size=4096,  # Use a batch size of 4096 for the embedding model
        output_format="labeled-pair",  # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
        use_faiss=True,
    )
    logging.info(hard_train_dataset)

    # 2c. (Optionally) Save the hard training dataset to disk
    # hard_train_dataset.save_to_disk("trivia-qa-hard-train")
    # Load again with:
    # hard_train_dataset = load_from_disk("trivia-qa-hard-train")

    # 3. Define our training loss.
    # pos_weight is recommended to be set as the ratio between positives to negatives, a.k.a. `num_hard_negatives`
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    # 4a. Define evaluators. We use the CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    # 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
    # We include the positive answer in the list of negatives, so the evaluator can use the performance of the
    # embedding model as a baseline.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # Use the full dataset as the corpus
        num_negatives=30,  # How many documents to rerank
        batch_size=4096,
        include_positives=True,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)
    # Columns 0 and 1 of the mined dataset are query/answer; the remainder are candidates.
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": sample["query"],
                "positive": [sample["answer"]],
                "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        name="trivia-qa-dev",
        always_rerank_positives=False,
    )

    # 4c. Combine the evaluators & run the base model on them
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-trivia-qa-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        # Must match the reranking evaluator's name ("trivia-qa-dev") above.
        metric_for_best_model="eval_trivia-qa-dev_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=200,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    # NOTE(review): no eval_dataset is passed here -- evaluation at eval_steps presumably
    # runs via `evaluator` only; confirm this is intended with eval_strategy="steps".
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model, useful to include these in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
|