tomaarsen HF staff committed on
Commit
4710a86
·
verified ·
1 Parent(s): d7ced1e

Upload training & evaluation script

Browse files
Files changed (2) hide show
  1. evaluate_trivia_qa.py +285 -0
  2. training_trivia_qa_bce.py +172 -0
evaluate_trivia_qa.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from datasets import load_dataset
3
+ import torch
4
+
5
+ from sentence_transformers import SentenceTransformer
6
+ from sentence_transformers.cross_encoder import CrossEncoder
7
+ from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
8
+ from sentence_transformers.util import mine_hard_negatives
9
+
10
+ # Set the log level to INFO to get more information
11
+ logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
12
+
13
+
14
def main():
    """Evaluate a collection of reranker models on the TriviaQA reranking task.

    Every model is measured in two settings:

    - "realistic" (``always_rerank_positives=False``): only the mined top-30
      candidate documents are reranked, so a positive answer that the retriever
      missed cannot be recovered by the reranker.
    - "evaluation" (``always_rerank_positives=True``): the positive answer is
      always part of the documents to rerank, isolating reranker quality from
      retriever quality.
    """
    batch_size = 16

    # Load the TriviaQA dataset and keep the first 1000 samples for evaluation:
    # https://huggingface.co/datasets/sentence-transformers/trivia-qa
    logging.info("Read the trivia-qa reranking dataset")
    full_dataset = load_dataset("sentence-transformers/trivia-qa", split="train")
    eval_dataset = full_dataset.select(range(1000))
    logging.info(eval_dataset)

    # Mine hard negatives with a cheap embedding model. The positive answer is
    # deliberately kept in the candidate list so that the embedding model's own
    # ranking serves as the "base" (pre-reranking) performance in the output.
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # Search over every answer in the full dataset
        num_negatives=30,  # Number of documents handed to the reranker
        batch_size=4096,
        include_positives=True,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)

    # Reshape the mined n-tuples into the evaluator's sample format. The first
    # two columns hold the query and the positive; the rest are candidates.
    candidate_columns = hard_eval_dataset.column_names[2:]
    samples = []
    for record in hard_eval_dataset:
        samples.append(
            {
                "query": record["query"],
                "positive": [record["answer"]],
                "documents": [record[column] for column in candidate_columns],
            }
        )

    # Realistic setting: rerank only what the retriever surfaced.
    realistic_reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=samples,
        batch_size=batch_size,
        name="trivia-qa-dev-realistic",
        always_rerank_positives=False,
        show_progress_bar=True,
    )
    # Evaluation setting: the positive is always among the reranked documents.
    evaluation_reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=samples,
        batch_size=batch_size,
        name="trivia-qa-dev-evaluation",
        always_rerank_positives=True,
        show_progress_bar=True,
    )

    model_names = [
        "tomaarsen/reranker-ModernBERT-base-trivia-qa-bce",
        "cross-encoder/ms-marco-MiniLM-L6-v2",
        "jinaai/jina-reranker-v1-tiny-en",
        "jinaai/jina-reranker-v1-turbo-en",
        "jinaai/jina-reranker-v2-base-multilingual",
        "BAAI/bge-reranker-base",
        "BAAI/bge-reranker-large",
        "BAAI/bge-reranker-v2-m3",
        "mixedbread-ai/mxbai-rerank-xsmall-v1",
        "mixedbread-ai/mxbai-rerank-base-v1",
        "mixedbread-ai/mxbai-rerank-large-v1",
        # "mixedbread-ai/mxbai-rerank-base-v2",
        # "mixedbread-ai/mxbai-rerank-large-v2",
        "Alibaba-NLP/gte-reranker-modernbert-base",
    ]
    for model_name in model_names:
        # Load the model to evaluate.
        # NOTE(review): jina models need max_length=1024, bge-reranker-base and
        # -large need max_length=512 — not set here; confirm intended.
        logging.info(f"Loading {model_name} model")
        cross_encoder = CrossEncoder(model_name, model_kwargs={"torch_dtype": torch.bfloat16}, trust_remote_code=True)

        # Run both evaluators and print their metric dictionaries.
        logging.info(f"Evaluating {model_name}")
        print(model_name)
        print(realistic_reranking_evaluator(cross_encoder))
        print(evaluation_reranking_evaluator(cross_encoder))
92
# Script entry point: run the full evaluation when executed directly.
if __name__ == "__main__":
    main()
94
+
95
+ """
96
+ 2025-03-28 10:17:20 - Evaluating tomaarsen/reranker-ModernBERT-base-trivia-qa-bce
97
+ 2025-03-28 10:17:20 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
98
+ 2025-03-28 10:21:06 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
99
+ 2025-03-28 10:21:06 - Base -> Reranked
100
+ 2025-03-28 10:21:06 - MAP: 46.91 -> 65.07
101
+ 2025-03-28 10:21:06 - MRR@10: 46.19 -> 65.33
102
+ 2025-03-28 10:21:06 - NDCG@10: 52.31 -> 69.21
103
+ 2025-03-28 10:21:06 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
104
+ 2025-03-28 10:25:39 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
105
+ 2025-03-28 10:25:39 - Base -> Reranked
106
+ 2025-03-28 10:25:39 - MAP: 46.91 -> 76.18
107
+ 2025-03-28 10:25:39 - MRR@10: 46.19 -> 76.46
108
+ 2025-03-28 10:25:39 - NDCG@10: 52.31 -> 81.16
109
+
110
+ 2025-03-28 10:25:39 - Loading cross-encoder/ms-marco-MiniLM-L6-v2 model
111
+ 2025-03-28 10:25:40 - Use pytorch device: cuda
112
+ 2025-03-28 10:25:40 - Evaluating cross-encoder/ms-marco-MiniLM-L6-v2
113
+ 2025-03-28 10:25:40 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
114
+ 2025-03-28 10:26:08 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
115
+ 2025-03-28 10:26:08 - Base -> Reranked
116
+ 2025-03-28 10:26:08 - MAP: 46.91 -> 59.97
117
+ 2025-03-28 10:26:08 - MRR@10: 46.19 -> 59.72
118
+ 2025-03-28 10:26:08 - NDCG@10: 52.31 -> 64.26
119
+ 2025-03-28 10:26:08 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
120
+ 2025-03-28 10:26:41 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
121
+ 2025-03-28 10:26:41 - Base -> Reranked
122
+ 2025-03-28 10:26:41 - MAP: 46.91 -> 65.99
123
+ 2025-03-28 10:26:41 - MRR@10: 46.19 -> 65.41
124
+ 2025-03-28 10:26:41 - NDCG@10: 52.31 -> 70.82
125
+
126
+ 2025-03-28 10:26:41 - Loading jinaai/jina-reranker-v1-tiny-en model
127
+ 2025-03-28 10:26:43 - Use pytorch device: cuda
128
+ 2025-03-28 10:26:44 - Evaluating jinaai/jina-reranker-v1-tiny-en
129
+ 2025-03-28 10:26:44 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
130
+ 2025-03-28 10:28:49 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
131
+ 2025-03-28 10:28:49 - Base -> Reranked
132
+ 2025-03-28 10:28:49 - MAP: 46.91 -> 59.46
133
+ 2025-03-28 10:28:49 - MRR@10: 46.19 -> 59.57
134
+ 2025-03-28 10:28:49 - NDCG@10: 52.31 -> 64.11
135
+ 2025-03-28 10:28:49 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
136
+ 2025-03-28 10:30:53 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
137
+ 2025-03-28 10:30:53 - Base -> Reranked
138
+ 2025-03-28 10:30:53 - MAP: 46.91 -> 65.14
139
+ 2025-03-28 10:30:53 - MRR@10: 46.19 -> 65.69
140
+ 2025-03-28 10:30:53 - NDCG@10: 52.31 -> 70.47
141
+
142
+ 2025-03-28 10:33:24 - Evaluating jinaai/jina-reranker-v1-turbo-en
143
+ 2025-03-28 10:33:24 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
144
+ 2025-03-28 10:36:16 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
145
+ 2025-03-28 10:36:16 - Base -> Reranked
146
+ 2025-03-28 10:36:16 - MAP: 46.91 -> 59.21
147
+ 2025-03-28 10:36:16 - MRR@10: 46.19 -> 59.03
148
+ 2025-03-28 10:36:16 - NDCG@10: 52.31 -> 63.75
149
+ 2025-03-28 10:36:16 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
150
+ 2025-03-28 10:39:09 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
151
+ 2025-03-28 10:39:09 - Base -> Reranked
152
+ 2025-03-28 10:39:09 - MAP: 46.91 -> 65.39
153
+ 2025-03-28 10:39:09 - MRR@10: 46.19 -> 65.04
154
+ 2025-03-28 10:39:09 - NDCG@10: 52.31 -> 70.70
155
+
156
+ 2025-03-28 10:54:00 - Evaluating jinaai/jina-reranker-v2-base-multilingual
157
+ 2025-03-28 10:54:00 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
158
+ 2025-03-28 10:56:49 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
159
+ 2025-03-28 10:56:49 - Base -> Reranked
160
+ 2025-03-28 10:56:49 - MAP: 46.91 -> 61.59
161
+ 2025-03-28 10:56:49 - MRR@10: 46.19 -> 61.63
162
+ 2025-03-28 10:56:49 - NDCG@10: 52.31 -> 66.16
163
+ 2025-03-28 10:56:49 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
164
+ 2025-03-28 11:00:14 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
165
+ 2025-03-28 11:00:14 - Base -> Reranked
166
+ 2025-03-28 11:00:14 - MAP: 46.91 -> 69.46
167
+ 2025-03-28 11:00:14 - MRR@10: 46.19 -> 69.47
168
+ 2025-03-28 11:00:14 - NDCG@10: 52.31 -> 74.74
169
+
170
+ 2025-03-28 11:14:39 - Evaluating BAAI/bge-reranker-base
171
+ BAAI/bge-reranker-base
172
+ 2025-03-28 11:14:39 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
173
+ 2025-03-28 11:15:42 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
174
+ 2025-03-28 11:15:42 - Base -> Reranked
175
+ 2025-03-28 11:15:42 - MAP: 46.91 -> 52.59
176
+ 2025-03-28 11:15:42 - MRR@10: 46.19 -> 54.37
177
+ 2025-03-28 11:15:42 - NDCG@10: 52.31 -> 59.41
178
+ {'trivia-qa-dev-realistic_map': 0.5259323270998224, 'trivia-qa-dev-realistic_mrr@10': 0.5436916666666666, 'trivia-qa-dev-realistic_ndcg@10': 0.5940858318692244, 'trivia-qa-dev-realistic_base_map': 0.46913299504084743, 'trivia-qa-dev-realistic_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-realistic_base_ndcg@10': 0.5231399731095658}
179
+ 2025-03-28 11:15:42 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
180
+ 2025-03-28 11:16:58 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
181
+ 2025-03-28 11:16:58 - Base -> Reranked
182
+ 2025-03-28 11:16:58 - MAP: 46.91 -> 59.04
183
+ 2025-03-28 11:16:58 - MRR@10: 46.19 -> 62.72
184
+ 2025-03-28 11:16:58 - NDCG@10: 52.31 -> 66.61
185
+ {'trivia-qa-dev-evaluation_map': 0.590369907985238, 'trivia-qa-dev-evaluation_mrr@10': 0.6271603174603174, 'trivia-qa-dev-evaluation_ndcg@10': 0.6661314116591164, 'trivia-qa-dev-evaluation_base_map': 0.46913299504084743, 'trivia-qa-dev-evaluation_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-evaluation_base_ndcg@10': 0.5231399731095658}
186
+
187
+ 2025-03-28 11:17:01 - Evaluating BAAI/bge-reranker-large
188
+ BAAI/bge-reranker-large
189
+ 2025-03-28 11:17:01 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-realistic dataset:
190
+ 2025-03-28 11:19:48 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
191
+ 2025-03-28 11:19:48 - Base -> Reranked
192
+ 2025-03-28 11:19:48 - MAP: 46.91 -> 55.68
193
+ 2025-03-28 11:19:48 - MRR@10: 46.19 -> 56.87
194
+ 2025-03-28 11:19:48 - NDCG@10: 52.31 -> 61.76
195
+ {'trivia-qa-dev-realistic_map': 0.5567810015278374, 'trivia-qa-dev-realistic_mrr@10': 0.5687162698412699, 'trivia-qa-dev-realistic_ndcg@10': 0.6176060985342933, 'trivia-qa-dev-realistic_base_map': 0.46913299504084743, 'trivia-qa-dev-realistic_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-realistic_base_ndcg@10': 0.5231399731095658}
196
+ 2025-03-28 11:19:48 - CrossEncoderRerankingEvaluator: Evaluating the model on the trivia-qa-dev-evaluation dataset:
197
+ 2025-03-28 11:23:10 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
198
+ 2025-03-28 11:23:10 - Base -> Reranked
199
+ 2025-03-28 11:23:10 - MAP: 46.91 -> 62.93
200
+ 2025-03-28 11:23:10 - MRR@10: 46.19 -> 65.24
201
+ 2025-03-28 11:23:10 - NDCG@10: 52.31 -> 69.92
202
+ {'trivia-qa-dev-evaluation_map': 0.6292598276500135, 'trivia-qa-dev-evaluation_mrr@10': 0.6523801587301588, 'trivia-qa-dev-evaluation_ndcg@10': 0.6992497211715496, 'trivia-qa-dev-evaluation_base_map': 0.46913299504084743, 'trivia-qa-dev-evaluation_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-evaluation_base_ndcg@10': 0.5231399731095658}
203
+
204
+ bge-reranker-v2-m3
205
+ 2025-03-28 11:33:42 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
206
+ 2025-03-28 11:33:42 - Base -> Reranked
207
+ 2025-03-28 11:33:42 - MAP: 46.91 -> 59.46
208
+ 2025-03-28 11:33:42 - MRR@10: 46.19 -> 60.49
209
+ 2025-03-28 11:33:42 - NDCG@10: 52.31 -> 64.85
210
+ {'trivia-qa-dev-realistic_map': 0.5945974714456489, 'trivia-qa-dev-realistic_mrr@10': 0.6049440476190477, 'trivia-qa-dev-realistic_ndcg@10': 0.6485089522801432, 'trivia-qa-dev-realistic_base_map': 0.46913299504084743, 'trivia-qa-dev-realistic_base_mrr@10': 0.4618571428571429, 'trivia-qa-dev-realistic_base_ndcg@10': 0.523139973109566}
211
+ 2025-03-28 11:33:42 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-evaluation dataset:
212
+ 2025-03-28 11:41:37 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
213
+ 2025-03-28 11:41:37 - Base -> Reranked
214
+ 2025-03-28 11:41:37 - MAP: 46.91 -> 67.12
215
+ 2025-03-28 11:41:37 - MRR@10: 46.19 -> 68.82
216
+ 2025-03-28 11:41:37 - NDCG@10: 52.31 -> 73.29
217
+ {'gooaq-dev-evaluation_map': 0.6712088650084718, 'gooaq-dev-evaluation_mrr@10': 0.6881884920634921, 'gooaq-dev-evaluation_ndcg@10': 0.7329280539251892, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
218
+
219
+ 2025-03-28 11:41:38 - Evaluating mixedbread-ai/mxbai-rerank-xsmall-v1
220
+ mixedbread-ai/mxbai-rerank-xsmall-v1
221
+ 2025-03-28 11:41:38 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-realistic dataset:
222
+ 2025-03-28 11:42:59 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
223
+ 2025-03-28 11:42:59 - Base -> Reranked
224
+ 2025-03-28 11:42:59 - MAP: 46.91 -> 58.07
225
+ 2025-03-28 11:42:59 - MRR@10: 46.19 -> 57.79
226
+ 2025-03-28 11:42:59 - NDCG@10: 52.31 -> 62.53
227
+ {'gooaq-dev-realistic_map': 0.5806904579481978, 'gooaq-dev-realistic_mrr@10': 0.5778678571428572, 'gooaq-dev-realistic_ndcg@10': 0.6252806166705941, 'gooaq-dev-realistic_base_map': 0.46913299504084743, 'gooaq-dev-realistic_base_mrr@10': 0.4618571428571429, 'gooaq-dev-realistic_base_ndcg@10': 0.523139973109566}
228
+ 2025-03-28 11:42:59 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-evaluation dataset:
229
+ 2025-03-28 11:44:37 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
230
+ 2025-03-28 11:44:37 - Base -> Reranked
231
+ 2025-03-28 11:44:37 - MAP: 46.91 -> 63.42
232
+ 2025-03-28 11:44:37 - MRR@10: 46.19 -> 63.23
233
+ 2025-03-28 11:44:37 - NDCG@10: 52.31 -> 68.38
234
+ {'gooaq-dev-evaluation_map': 0.6341866885097522, 'gooaq-dev-evaluation_mrr@10': 0.6323373015873017, 'gooaq-dev-evaluation_ndcg@10': 0.683834674484044, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
235
+
236
+ 2025-03-28 11:44:38 - Evaluating mixedbread-ai/mxbai-rerank-base-v1
237
+ mixedbread-ai/mxbai-rerank-base-v1
238
+ 2025-03-28 11:44:38 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-realistic dataset:
239
+ 2025-03-28 11:47:02 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
240
+ 2025-03-28 11:47:02 - Base -> Reranked
241
+ 2025-03-28 11:47:02 - MAP: 46.91 -> 55.81
242
+ 2025-03-28 11:47:02 - MRR@10: 46.19 -> 55.68
243
+ 2025-03-28 11:47:02 - NDCG@10: 52.31 -> 60.99
244
+ {'gooaq-dev-realistic_map': 0.5580880756399167, 'gooaq-dev-realistic_mrr@10': 0.5567904761904762, 'gooaq-dev-realistic_ndcg@10': 0.6099184869749001, 'gooaq-dev-realistic_base_map': 0.46913299504084743, 'gooaq-dev-realistic_base_mrr@10': 0.4618571428571429, 'gooaq-dev-realistic_base_ndcg@10': 0.523139973109566}
245
+ 2025-03-28 11:47:02 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-evaluation dataset:
246
+ 2025-03-28 11:49:56 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
247
+ 2025-03-28 11:49:56 - Base -> Reranked
248
+ 2025-03-28 11:49:56 - MAP: 46.91 -> 62.12
249
+ 2025-03-28 11:49:56 - MRR@10: 46.19 -> 61.81
250
+ 2025-03-28 11:49:56 - NDCG@10: 52.31 -> 67.69
251
+ {'gooaq-dev-evaluation_map': 0.62120330763951, 'gooaq-dev-evaluation_mrr@10': 0.6180714285714286, 'gooaq-dev-evaluation_ndcg@10': 0.6769237801354084, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
252
+
253
+ 2025-03-28 11:49:57 - Evaluating mixedbread-ai/mxbai-rerank-large-v1
254
+ mixedbread-ai/mxbai-rerank-large-v1
255
+ 2025-03-28 11:49:57 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-realistic dataset:
256
+ 2025-03-28 11:56:18 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
257
+ 2025-03-28 11:56:18 - Base -> Reranked
258
+ 2025-03-28 11:56:18 - MAP: 46.91 -> 58.13
259
+ 2025-03-28 11:56:18 - MRR@10: 46.19 -> 58.63
260
+ 2025-03-28 11:56:18 - NDCG@10: 52.31 -> 63.38
261
+ {'gooaq-dev-realistic_map': 0.5813141616006278, 'gooaq-dev-realistic_mrr@10': 0.5862551587301587, 'gooaq-dev-realistic_ndcg@10': 0.6337779476332643, 'gooaq-dev-realistic_base_map': 0.46913299504084743, 'gooaq-dev-realistic_base_mrr@10': 0.4618571428571429, 'gooaq-dev-realistic_base_ndcg@10': 0.523139973109566}
262
+ 2025-03-28 11:56:18 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-evaluation dataset:
263
+ 2025-03-28 12:03:56 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
264
+ 2025-03-28 12:03:56 - Base -> Reranked
265
+ 2025-03-28 12:03:56 - MAP: 46.91 -> 64.79
266
+ 2025-03-28 12:03:56 - MRR@10: 46.19 -> 65.64
267
+ 2025-03-28 12:03:56 - NDCG@10: 52.31 -> 70.65
268
+ {'gooaq-dev-evaluation_map': 0.64787303000053, 'gooaq-dev-evaluation_mrr@10': 0.6563722222222222, 'gooaq-dev-evaluation_ndcg@10': 0.7065480874647052, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
269
+
270
+ 2025-03-28 12:03:58 - Evaluating Alibaba-NLP/gte-reranker-modernbert-base
271
+ Alibaba-NLP/gte-reranker-modernbert-base
272
+ 2025-03-28 12:03:58 - CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev-realistic dataset:
273
+ 2025-03-28 12:07:46 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
274
+ 2025-03-28 12:07:46 - Base -> Reranked
275
+ 2025-03-28 12:07:46 - MAP: 46.91 -> 61.20
276
+ 2025-03-28 12:07:46 - MRR@10: 46.19 -> 61.96
277
+ 2025-03-28 12:07:46 - NDCG@10: 52.31 -> 65.92
278
+ {'gooaq-dev-realistic_map': 0.6119823635603096, 'gooaq-dev-realistic_mrr@10': 0.6195595238095238, 'gooaq-dev-realistic_ndcg@10': 0.6591716946749766, 'gooaq-dev-realistic_base_map': 0.46913299504084743, 'gooaq-dev-realistic_base_mrr@10': 0.4618571428571429, 'gooaq-dev-realistic_base_ndcg@10': 0.523139973109566}
279
+ 2025-03-28 12:12:23 - Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 29.0, Mean 29.2, Max 30.0
280
+ 2025-03-28 12:12:23 - Base -> Reranked
281
+ 2025-03-28 12:12:23 - MAP: 46.91 -> 69.22
282
+ 2025-03-28 12:12:23 - MRR@10: 46.19 -> 70.84
283
+ 2025-03-28 12:12:23 - NDCG@10: 52.31 -> 74.88
284
+ {'gooaq-dev-evaluation_map': 0.6921869586576195, 'gooaq-dev-evaluation_mrr@10': 0.7084460317460317, 'gooaq-dev-evaluation_ndcg@10': 0.7487587539162411, 'gooaq-dev-evaluation_base_map': 0.46913299504084743, 'gooaq-dev-evaluation_base_mrr@10': 0.4618571428571429, 'gooaq-dev-evaluation_base_ndcg@10': 0.523139973109566}
285
+ """
training_trivia_qa_bce.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import traceback
3
+
4
+ import torch
5
+ from datasets import load_dataset
6
+
7
+ from sentence_transformers import SentenceTransformer
8
+ from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderModelCardData
9
+ from sentence_transformers.cross_encoder.evaluation import (
10
+ CrossEncoderNanoBEIREvaluator,
11
+ CrossEncoderRerankingEvaluator,
12
+ )
13
+ from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
14
+ from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
15
+ from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
16
+ from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
17
+ from sentence_transformers.util import mine_hard_negatives
18
+
19
+ # Set the log level to INFO to get more information
20
+ logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
21
+
22
+
23
def main():
    """Finetune answerdotai/ModernBERT-base as a TriviaQA reranker using
    binary cross-entropy loss over mined hard negatives."""
    model_name = "answerdotai/ModernBERT-base"

    batch_size = 16
    num_epochs = 1
    num_hard_negatives = 5  # Hard negatives mined per question-answer pair

    # 1. Load the model to finetune, with metadata for the generated model card.
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name="ModernBERT-base trained on TriviaQA",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2. Load the TriviaQA dataset and split it:
    # https://huggingface.co/datasets/sentence-transformers/trivia-qa
    # First 1000 samples for evaluation, the remainder for training.
    logging.info("Read the trivia-qa training dataset")
    full_dataset = load_dataset("sentence-transformers/trivia-qa", split="train")
    eval_dataset = full_dataset.select(range(1000))
    train_dataset = full_dataset.select(range(1000, len(full_dataset)))
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Mine hard negatives for training with a very efficient embedding model.
    # "labeled-pair" output yields (query, passage, label) rows as required by
    # BinaryCrossEntropyLoss.
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,  # Negatives per question-answer pair
        margin=0,  # Negatives must score at least this much below the positive
        range_min=0,  # Skip the x most similar candidates
        range_max=100,  # Consider only the x most similar candidates
        sampling_strategy="top",  # Take the top negatives from the range
        batch_size=4096,  # Embedding-model batch size
        output_format="labeled-pair",
        use_faiss=True,
    )
    logging.info(hard_train_dataset)

    # Optionally persist the mined dataset and reload it later:
    # hard_train_dataset.save_to_disk("trivia-qa-hard-train")
    # hard_train_dataset = load_from_disk("trivia-qa-hard-train")

    # 4. Training loss. pos_weight should match the positive:negative ratio,
    # i.e. num_hard_negatives.
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    # 5a. Light-weight English reranking evaluator over a few NanoBEIR subsets.
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=batch_size,
    )

    # 5b. Reranking evaluator built from mined hard negatives. Positives are
    # kept in the candidate list so the embedding model is the baseline.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # Search over every answer in the full dataset
        num_negatives=30,  # Number of documents to rerank
        batch_size=4096,
        include_positives=True,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)
    candidate_columns = hard_eval_dataset.column_names[2:]
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": record["query"],
                "positive": [record["answer"]],
                "documents": [record[column] for column in candidate_columns],
            }
            for record in hard_eval_dataset
        ],
        batch_size=batch_size,
        name="trivia-qa-dev",
        always_rerank_positives=False,
    )

    # 5c. Combine the evaluators and record the untrained model's baseline.
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)

    # 6. Training arguments.
    short_model_name = model_name.rsplit("/", 1)[-1]  # Strip any org prefix
    run_name = f"reranker-{short_model_name}-trivia-qa-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        metric_for_best_model="eval_trivia-qa-dev_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=200,
        logging_first_step=True,
        run_name=run_name,  # Used in W&B if `wandb` is installed
        seed=12,
    )

    # 7. Create the trainer and train.
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 8. Evaluate the final model; useful for the model card.
    evaluator(model)

    # 9. Save the final model locally.
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 10. (Optional) push to the Hugging Face Hub. Requires a prior
    # `huggingface-cli login`; on failure, log how to upload manually.
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )
169
+
170
+
171
# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()