feat-add-versions-to-benchmarks #28
by nan · opened

Files changed:
- Makefile +10 -0
- app.py +349 -289
- pyproject.toml +3 -3
- requirements.txt +2 -2
- src/about.py +1 -1
- src/benchmarks.py +44 -65
- src/{display/utils.py → columns.py} +53 -40
- src/{display/gradio_formatting.py → components.py} +24 -19
- src/{display/css_html_js.py → css_html_js.py} +0 -0
- src/display/formatting.py +0 -29
- src/display/gradio_listener.py +0 -53
- src/envs.py +40 -6
- src/loaders.py +88 -0
- src/{read_evals.py → models.py} +68 -122
- src/utils.py +267 -136
- tests/src/display/test_utils.py +0 -23
- tests/src/test_benchmarks.py +29 -5
- tests/src/test_columns.py +119 -0
- tests/src/test_envs.py +14 -0
- tests/src/test_loaders.py +46 -0
- tests/src/test_models.py +89 -0
- tests/src/test_read_evals.py +0 -68
- tests/src/test_utils.py +237 -0
- tests/test_utils.py +0 -115
- tests/toydata/eval_results/AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json +0 -0
- tests/toydata/eval_results/AIR-Bench_24.05/bge-m3/NoReranker/results.json +0 -0
- tests/toydata/test_data.json +0 -98
- tests/toydata/test_results/bge-m3/NoReranker/results_2023-11-21T18-10-08.json +0 -98
- tests/toydata/test_results/bge-m3/bge-reranker-v2-m3/results_2023-11-21T18-10-08.json +0 -98
Makefile
CHANGED
@@ -3,11 +3,21 @@
 
 style:
 	python -m black --line-length 119 .
+	python -m black --line-length 119 src
 	python -m isort .
+	python -m isort src
 	ruff check --fix .
+	ruff check --fix src
 
 
 quality:
 	python -m black --check --line-length 119 .
+	python -m black --check --line-length 119 src
 	python -m isort --check-only .
+	python -m isort --check-only src
 	ruff check .
+	ruff check src
+
+
+test:
+	python -m pytest tests
app.py
CHANGED
app.py is rebuilt around the per-version datastores. In outline:

- The import block is regrouped: UI helpers now come from src.components, column constants from src.columns, custom_css from src.css_html_js, the loader from src.loaders, and TaskType / model_hyperlink from src.models. The removed imports from src.display.utils, src.display.gradio_formatting and src.read_evals go away together with the old module-level construction of leaderboard_df_qa / leaderboard_df_long_doc and the old update_metric_qa helper; that work now happens in src/loaders.py.
- Results are snapshot-downloaded only when LOCAL_MODE is not set; every benchmark version is then loaded into ds_dict, and datastore points at LATEST_BENCHMARK_VERSION.
- update_qa_metric / update_doc_metric are thin wrappers around update_metric that pass the current datastore and a TaskType; new update_qa_version / update_doc_version callbacks rebuild the domain, language and reranking-model selectors and both leaderboard tables whenever the version dropdown changes.
- A version dropdown (get_version_dropdown) is added to the Results tab. Each task tab derives its domain and language choices from QABenchmarks[datastore.slug] or LongDocBenchmarks[datastore.slug] and its metric dropdown from METRIC_LIST with the task's default metric. All six sub-tabs (QA and Long Doc, each with Retrieval + Reranking, Retrieval Only and Reranking Only) now register a version.change listener next to the existing set_listeners and metric.change wiring.
- The Submit tab hunks are formatting-only (trailing commas, closing parentheses).

Reconstructed head of the new file:

import os

import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import snapshot_download

from src.about import BENCHMARKS_TEXT, EVALUATION_QUEUE_TEXT, INTRODUCTION_TEXT, TITLE
from src.benchmarks import LongDocBenchmarks, QABenchmarks
from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
from src.components import (
    get_anonymous_checkbox,
    get_domain_dropdown,
    get_language_dropdown,
    get_leaderboard_table,
    get_metric_dropdown,
    get_noreranking_dropdown,
    get_reranking_dropdown,
    get_revision_and_ts_checkbox,
    get_search_bar,
    get_version_dropdown,
)
from src.css_html_js import custom_css
from src.envs import (
    API,
    BENCHMARK_VERSION_LIST,
    DEFAULT_METRIC_LONG_DOC,
    DEFAULT_METRIC_QA,
    EVAL_RESULTS_PATH,
    LATEST_BENCHMARK_VERSION,
    METRIC_LIST,
    REPO_ID,
    RESULTS_REPO,
    TOKEN,
)
from src.loaders import load_eval_results
from src.models import TaskType, model_hyperlink
from src.utils import remove_html, reset_rank, set_listeners, submit_results, update_metric, upload_file


def restart_space():
    API.restart_space(repo_id=REPO_ID)


try:
    if not os.environ.get("LOCAL_MODE", False):
        print("Running in local mode")
        snapshot_download(
            repo_id=RESULTS_REPO,
            local_dir=EVAL_RESULTS_PATH,
            repo_type="dataset",
            tqdm_class=None,
            etag_timeout=30,
            token=TOKEN,
        )
except Exception:
    print("failed to download")
    restart_space()

global ds_dict
ds_dict = load_eval_results(EVAL_RESULTS_PATH)
global datastore
datastore = ds_dict[LATEST_BENCHMARK_VERSION]


def update_qa_metric(
    metric: str,
    domains: list,
    langs: list,
    reranking_model: list,
    query: str,
    show_anonymous: bool,
    show_revision_and_timestamp: bool,
):
    global datastore
    return update_metric(
        datastore,
        TaskType.qa,
        metric,
        domains,
        langs,
        reranking_model,
        query,
        show_anonymous,
        show_revision_and_timestamp,
    )


def update_qa_version(version):
    global datastore
    global ds_dict
    datastore = ds_dict[version]
    domain_elem = get_domain_dropdown(QABenchmarks[datastore.slug])
    lang_elem = get_language_dropdown(QABenchmarks[datastore.slug])
    model_elem = get_reranking_dropdown(datastore.reranking_models)
    df_elem = get_leaderboard_table(datastore.qa_fmt_df, datastore.qa_types)
    hidden_df_elem = get_leaderboard_table(datastore.qa_raw_df, datastore.qa_types, visible=False)
    return domain_elem, lang_elem, model_elem, df_elem, hidden_df_elem

update_doc_metric and update_doc_version mirror the two functions above, using TaskType.long_doc, LongDocBenchmarks and the doc_* dataframes.

demo = gr.Blocks(css=custom_css)

BM25_LINK = model_hyperlink("https://github.com/castorini/pyserini", "BM25")

Inside the "with demo:" block every sub-tab follows the same wiring pattern; the QA "Retrieval + Reranking" tab is representative (the other five swap in their own dataframes, dropdowns and TaskType):

models = get_reranking_dropdown(datastore.reranking_models)
# shown_table
qa_df_elem_ret_rerank = get_leaderboard_table(datastore.qa_fmt_df, datastore.qa_types)
# Dummy leaderboard for handling the case when the user uses backspace key
qa_df_elem_ret_rerank_hidden = get_leaderboard_table(
    datastore.qa_raw_df, datastore.qa_types, visible=False
)

version.change(
    update_qa_version,
    version,
    [domains, langs, models, qa_df_elem_ret_rerank, qa_df_elem_ret_rerank_hidden],
)

set_listeners(
    TaskType.qa,
    qa_df_elem_ret_rerank,
    qa_df_elem_ret_rerank_hidden,
    search_bar,
    version,
    domains,
    langs,
    models,
    show_anonymous,
    show_rev_ts,
)

# set metric listener
metric.change(
    update_qa_metric,
    [metric, domains, langs, models, search_bar, show_anonymous, show_rev_ts],
    qa_df_elem_ret_rerank,
    queue=True,
)

The Retrieval Only and Reranking Only sub-tabs additionally pre-filter the datastore's dataframes before building their tables, e.g. for QA:

_qa_df_ret = datastore.qa_fmt_df[datastore.qa_fmt_df[COL_NAME_RERANKING_MODEL] == "NoReranker"]
_qa_df_ret = reset_rank(_qa_df_ret)

_qa_df_rerank = datastore.qa_fmt_df[datastore.qa_fmt_df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
_qa_df_rerank = reset_rank(_qa_df_rerank)
qa_rerank_models = _qa_df_rerank[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
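The version-switching pattern above is easiest to see in isolation. Below is a minimal, standalone sketch rather than code from this PR: a toy dict of per-version dataframes stands in for ds_dict, and a Dropdown.change listener swaps which one the table shows. The only assumptions are that gradio (<5.0.0, as pinned in requirements.txt) and pandas are installed; the version names and scores are illustrative.

import gradio as gr
import pandas as pd

# Toy stand-in for the per-version datastores built by src/loaders.py.
ds_dict = {
    "AIR-Bench_24.04": pd.DataFrame({"Rank": [1], "Retrieval Method": ["bge-m3"], "Average": [0.52]}),
    "AIR-Bench_24.05": pd.DataFrame({"Rank": [1], "Retrieval Method": ["bge-m3"], "Average": [0.49]}),
}


def update_version(version: str):
    # Return an update for the table bound as this listener's output.
    return gr.update(value=ds_dict[version])


with gr.Blocks() as demo:
    version = gr.Dropdown(choices=list(ds_dict), value="AIR-Bench_24.04", label="AIR-Bench Version")
    table = gr.Dataframe(value=ds_dict["AIR-Bench_24.04"])
    # Same shape as the PR's wiring: dropdown in, refreshed table out.
    version.change(update_version, version, table)

if __name__ == "__main__":
    demo.launch()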
pyproject.toml
CHANGED
The Ruff settings move under the lint.* namespace that newer Ruff releases expect:

@@ -1,9 +1,9 @@
 [tool.ruff]
 # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
-select = ["E", "F"]
-ignore = ["E501"]  # line too long (black is taking care of this)
+lint.select = ["E", "F"]
+lint.ignore = ["E501"]  # line too long (black is taking care of this)
 line-length = 119
-fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
+lint.fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
 
 [tool.isort]
 profile = "black"
requirements.txt
CHANGED
@@ -2,7 +2,7 @@ APScheduler>=3.10.1
 black>=23.11.0
 click>=8.1.3
 datasets>=2.14.5
-gradio
+gradio<5.0.0
 gradio_client>=0.16.1
 huggingface-hub>=0.18.0
 numpy>=1.24.2
@@ -12,4 +12,4 @@ requests>=2.31.0
 tqdm>=4.65.0
 accelerate>=0.24.1
 socksio>=1.0.0
-air-benchmark>=0.0
+air-benchmark>=0.1.0
src/about.py
CHANGED
@@ -8,7 +8,7 @@ INTRODUCTION_TEXT = """
 """
 
 # Which evaluations are you running? how can people reproduce what you have?
-BENCHMARKS_TEXT =
+BENCHMARKS_TEXT = """
 ## How the test data are generated?
 ### Find more information at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/data_generation.md)
 
src/benchmarks.py
CHANGED
src/benchmarks.py is rewritten to build one benchmark Enum per AIR-Bench version. get_safe_name moves to src.models and METRIC_LIST to src.envs; the old module-level BenchmarksQA / BenchmarksLongDoc enums and the DOMAIN_COLS_LONG_DOC / LANG_COLS_LONG_DOC lists are removed. New content:

from dataclasses import dataclass
from enum import Enum

from air_benchmark.tasks.tasks import BenchmarkTable

from src.envs import BENCHMARK_VERSION_LIST, METRIC_LIST
from src.models import TaskType, get_safe_name


@dataclass
class Benchmark:
    name: str  # [domain]_[language]_[metric], task_key in the json file
    metric: str  # metric_key in the json file
    col_name: str  # [domain]_[language], name to display in the leaderboard
    domain: str
    lang: str
    task: str


# create a function return an enum class containing all the benchmarks
def get_qa_benchmarks_dict(version: str):
    benchmark_dict = {}
    for task, domain_dict in BenchmarkTable[version].items():
        if task != TaskType.qa.value:
            continue
        for domain, lang_dict in domain_dict.items():
            for lang, dataset_list in lang_dict.items():
                benchmark_name = get_safe_name(f"{domain}_{lang}")
                col_name = benchmark_name
                for metric in dataset_list:
                    if "test" not in dataset_list[metric]["splits"]:
                        continue
                    benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
    return benchmark_dict


def get_doc_benchmarks_dict(version: str):
    benchmark_dict = {}
    for task, domain_dict in BenchmarkTable[version].items():
        if task != TaskType.long_doc.value:
            continue
        for domain, lang_dict in domain_dict.items():
            for lang, dataset_list in lang_dict.items():
                for dataset in dataset_list:
                    benchmark_name = f"{domain}_{lang}_{dataset}"
                    benchmark_name = get_safe_name(benchmark_name)
                    col_name = benchmark_name
                    if "test" not in dataset_list[dataset]["splits"]:
                        continue
                    for metric in METRIC_LIST:
                        benchmark_dict[benchmark_name] = Benchmark(
                            benchmark_name, metric, col_name, domain, lang, task
                        )
    return benchmark_dict


_qa_benchmark_dict = {}
for version in BENCHMARK_VERSION_LIST:
    safe_version_name = get_safe_name(version)
    _qa_benchmark_dict[safe_version_name] = Enum(f"QABenchmarks_{safe_version_name}", get_qa_benchmarks_dict(version))

_doc_benchmark_dict = {}
for version in BENCHMARK_VERSION_LIST:
    safe_version_name = get_safe_name(version)
    _doc_benchmark_dict[safe_version_name] = Enum(
        f"LongDocBenchmarks_{safe_version_name}", get_doc_benchmarks_dict(version)
    )

QABenchmarks = Enum("QABenchmarks", _qa_benchmark_dict)
LongDocBenchmarks = Enum("LongDocBenchmarks", _doc_benchmark_dict)
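For readers unfamiliar with this "Enum of Enums" construction, here is a standalone sketch of the same pattern with made-up benchmark entries; the real entries come from air_benchmark's BenchmarkTable, so everything except the lookup shape is illustrative.

from dataclasses import dataclass
from enum import Enum


@dataclass
class Benchmark:
    name: str
    metric: str
    col_name: str
    domain: str
    lang: str
    task: str


def get_safe_name(name: str) -> str:
    # RFC 1123-style slug: lowercase, keep alphanumerics and underscores.
    name = name.replace("-", "_")
    return "".join(c.lower() for c in name if c.isalnum() or c == "_")


# Toy per-version benchmark tables (hypothetical domains and languages).
_versions = {
    "AIR-Bench_24.04": {
        "wiki_en": Benchmark("wiki_en", "ndcg_at_10", "wiki_en", "wiki", "en", "qa"),
    },
    "AIR-Bench_24.05": {
        "wiki_en": Benchmark("wiki_en", "ndcg_at_10", "wiki_en", "wiki", "en", "qa"),
        "news_zh": Benchmark("news_zh", "ndcg_at_10", "news_zh", "news", "zh", "qa"),
    },
}

# One inner Enum per version, wrapped in an outer Enum keyed by the safe version name.
_inner = {get_safe_name(v): Enum(f"QABenchmarks_{get_safe_name(v)}", members) for v, members in _versions.items()}
QABenchmarks = Enum("QABenchmarks", _inner)

# Look up one version by its slug and derive its domain list, as src/components.py does.
slug = get_safe_name("AIR-Bench_24.05")
domains = sorted({b.value.domain for b in list(QABenchmarks[slug].value)})
print(domains)  # ['news', 'wiki']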
src/{display/utils.py → columns.py}
RENAMED
src/display/utils.py is renamed to src/columns.py. The import of BenchmarksQA / BenchmarksLongDoc is dropped, fields() becomes the private helper _fields(), make_autoevalcolumn loses its BenchmarksQA defaults, and the per-task COLS_QA / COLS_LONG_DOC / TYPES_* / *_BENCHMARK_COLS constants are replaced by two helpers that derive column names and types from whichever benchmark enum is passed in. The COL_NAME_* constants move to the bottom of the module, and get_default_auto_eval_column_dict() now registers the columns explicitly in this order: rank, retrieval_model, reranking_model, revision, timestamp, average, retrieval_model_link (hidden), reranking_model_link (hidden), is_anonymous (hidden). Key new code:

def make_autoevalcolumn(cls_name, benchmarks):
    auto_eval_column_dict = get_default_auto_eval_column_dict()
    # Leaderboard columns
    for benchmark in list(benchmarks.value):
        auto_eval_column_dict.append(
            [benchmark.name, ColumnContent, ColumnContent(benchmark.value.col_name, "number", True)]
        )
    return make_dataclass(cls_name, auto_eval_column_dict, frozen=True)


def get_default_col_names_and_types(benchmarks):
    AutoEvalColumn = make_autoevalcolumn("AutoEvalColumn", benchmarks)
    col_names = [c.name for c in _fields(AutoEvalColumn) if not c.hidden]
    col_types = [c.type for c in _fields(AutoEvalColumn) if not c.hidden]
    return col_names, col_types


def get_fixed_col_names_and_types():
    fixed_cols = get_default_auto_eval_column_dict()[:-3]
    return [c.name for _, _, c in fixed_cols], [c.type for _, _, c in fixed_cols]


COL_NAME_AVG = "Average ⬆️"
COL_NAME_RETRIEVAL_MODEL = "Retrieval Method"
COL_NAME_RERANKING_MODEL = "Reranking Model"
COL_NAME_RETRIEVAL_MODEL_LINK = "Retrieval Model LINK"
COL_NAME_RERANKING_MODEL_LINK = "Reranking Model LINK"
COL_NAME_RANK = "Rank 🏆"
COL_NAME_REVISION = "Revision"
COL_NAME_TIMESTAMP = "Submission Date"
COL_NAME_IS_ANONYMOUS = "Anonymous Submission"
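The column machinery above leans on dataclasses.make_dataclass. The standalone sketch below is not the PR's code: it uses a handful of hypothetical columns, and ColumnContent is declared frozen here so its instances are hashable defaults on newer Python versions.

from dataclasses import dataclass, make_dataclass


@dataclass(frozen=True)
class ColumnContent:
    name: str
    type: str
    displayed_by_default: bool
    hidden: bool = False
    never_hidden: bool = False


def _fields(raw_class):
    # Pull the ColumnContent defaults back out of the generated dataclass.
    return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]


# A toy column registry: a few fixed columns plus one benchmark column ("wiki_en").
column_dict = [
    ["rank", ColumnContent, ColumnContent("Rank 🏆", "number", True)],
    ["retrieval_model", ColumnContent, ColumnContent("Retrieval Method", "markdown", True, never_hidden=True)],
    ["retrieval_model_link", ColumnContent, ColumnContent("Retrieval Model LINK", "markdown", False, hidden=True)],
    ["wiki_en", ColumnContent, ColumnContent("wiki_en", "number", True)],
]

# Build the dataclass at runtime, then filter on the hidden flag.
AutoEvalColumn = make_dataclass("AutoEvalColumn", column_dict, frozen=True)
col_names = [c.name for c in _fields(AutoEvalColumn) if not c.hidden]
col_types = [c.type for c in _fields(AutoEvalColumn) if not c.hidden]
print(col_names)  # ['Rank 🏆', 'Retrieval Method', 'wiki_en']
print(col_types)  # ['number', 'markdown', 'number']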
src/{display/gradio_formatting.py → components.py}
RENAMED
src/display/gradio_formatting.py is renamed to src/components.py. Most of the diff is black formatting (trailing commas, calls collapsed onto one line); the substantive changes are that get_domain_dropdown and get_language_dropdown now take a benchmark enum and derive their choices and defaults from it, and get_noreranking_dropdown pins its choices and value to ["NoReranker"]. Key new code:

def get_noreranking_dropdown():
    return gr.Dropdown(
        choices=[
            "NoReranker",
        ],
        value=[
            "NoReranker",
        ],
        interactive=False,
        multiselect=True,
        visible=False,
    )


def get_domain_dropdown(benchmarks, default_domains=None):
    domain_list = list(frozenset([c.value.domain for c in list(benchmarks.value)]))
    if default_domains is None:
        default_domains = domain_list
    return gr.CheckboxGroup(
        choices=domain_list,
        value=default_domains,
        # remaining keyword arguments unchanged
    )


def get_language_dropdown(benchmarks, default_languages=None):
    language_list = list(frozenset([c.value.lang for c in list(benchmarks.value)]))
    if default_languages is None:
        default_languages = language_list
    return gr.Dropdown(
        choices=language_list,
        value=default_languages,
        label="Select the languages",
        multiselect=True,
        interactive=True,
    )
src/{display/css_html_js.py → css_html_js.py}
RENAMED
File without changes
src/display/formatting.py
DELETED
@@ -1,29 +0,0 @@
Deleted file (model_hyperlink now lives in src/models.py, as app.py's new imports show):

def model_hyperlink(link, model_name):
    return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'


def make_clickable_model(model_name: str, model_link: str):
    # link = f"https://huggingface.co/{model_name}"
    if not model_link or not model_link.startswith("https://"):
        return model_name
    return model_hyperlink(model_link, model_name)


def styled_error(error):
    return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"


def styled_warning(warn):
    return f"<p style='color: orange; font-size: 20px; text-align: center;'>{warn}</p>"


def styled_message(message):
    return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"


def has_no_nan_values(df, columns):
    return df[columns].notna().all(axis=1)


def has_nan_values(df, columns):
    return df[columns].isna().any(axis=1)
src/display/gradio_listener.py
DELETED
@@ -1,53 +0,0 @@
Deleted file (a reworked set_listeners now lives in src/utils.py, which app.py imports):

from src.utils import update_table, update_table_long_doc


def set_listeners(
    task,
    displayed_leaderboard,
    hidden_leaderboard,
    search_bar,
    selected_domains,
    selected_langs,
    selected_rerankings,
    show_anonymous,
    show_revision_and_timestamp,
):
    if task == "qa":
        update_table_func = update_table
    elif task == "long-doc":
        update_table_func = update_table_long_doc
    else:
        raise NotImplementedError
    # Set search_bar listener
    search_bar.submit(
        update_table_func,
        [
            hidden_leaderboard,  # hidden_leaderboard_table_for_search,
            selected_domains,
            selected_langs,
            selected_rerankings,
            search_bar,
            show_anonymous,
        ],
        displayed_leaderboard
    )

    # Set column-wise listener
    for selector in [
        selected_domains, selected_langs, show_anonymous, show_revision_and_timestamp, selected_rerankings
    ]:
        selector.change(
            update_table_func,
            [
                hidden_leaderboard,
                selected_domains,
                selected_langs,
                selected_rerankings,
                search_bar,
                show_anonymous,
                show_revision_and_timestamp
            ],
            displayed_leaderboard,
            queue=True,
        )
src/envs.py
CHANGED
BM25_LINK is removed (it is now built in app.py from src.models.model_hyperlink), the OWNER assignment and the cache comment are rewrapped, and the module now defines the default metrics plus the METRIC_LIST previously kept in src/benchmarks.py. Key new content:

OWNER = (
    "AIR-Bench"  # Change to your org - don't forget to create a results and request dataset, with the correct format!
)

# If you set up a cache later, just change HF_HOME
CACHE_PATH = os.getenv("HF_HOME", ".")

LATEST_BENCHMARK_VERSION = BENCHMARK_VERSION_LIST[0]
DEFAULT_METRIC_QA = "ndcg_at_10"
DEFAULT_METRIC_LONG_DOC = "recall_at_10"
METRIC_LIST = [
    "ndcg_at_1",
    "ndcg_at_3",
    "ndcg_at_5",
    "ndcg_at_10",
    "ndcg_at_100",
    "ndcg_at_1000",
    "map_at_1",
    "map_at_3",
    "map_at_5",
    "map_at_10",
    "map_at_100",
    "map_at_1000",
    "recall_at_1",
    "recall_at_3",
    "recall_at_5",
    "recall_at_10",
    "recall_at_100",
    "recall_at_1000",
    "precision_at_1",
    "precision_at_3",
    "precision_at_5",
    "precision_at_10",
    "precision_at_100",
    "precision_at_1000",
    "mrr_at_1",
    "mrr_at_3",
    "mrr_at_5",
    "mrr_at_10",
    "mrr_at_100",
    "mrr_at_1000",
]
src/loaders.py
ADDED
@@ -0,0 +1,88 @@
import os.path
from pathlib import Path
from typing import Dict, List, Union

import pandas as pd

from src.columns import COL_NAME_IS_ANONYMOUS, COL_NAME_REVISION, COL_NAME_TIMESTAMP
from src.envs import BENCHMARK_VERSION_LIST, DEFAULT_METRIC_LONG_DOC, DEFAULT_METRIC_QA
from src.models import FullEvalResult, LeaderboardDataStore, TaskType, get_safe_name
from src.utils import get_default_cols, get_leaderboard_df, reset_rank

pd.options.mode.copy_on_write = True


def load_raw_eval_results(results_path: Union[Path, str]) -> List[FullEvalResult]:
    """
    Load the evaluation results from a json file
    """
    model_result_filepaths = []
    for root, dirs, files in os.walk(results_path):
        if len(files) == 0:
            continue

        # select the latest results
        for file in files:
            if not (file.startswith("results") and file.endswith(".json")):
                print(f"skip {file}")
                continue
            model_result_filepaths.append(os.path.join(root, file))

    eval_results = {}
    for model_result_filepath in model_result_filepaths:
        # create evaluation results
        try:
            eval_result = FullEvalResult.init_from_json_file(model_result_filepath)
        except UnicodeDecodeError:
            print(f"loading file failed. {model_result_filepath}")
            continue
        print(f"file loaded: {model_result_filepath}")
        timestamp = eval_result.timestamp
        eval_results[timestamp] = eval_result

    results = []
    for k, v in eval_results.items():
        try:
            v.to_dict()
            results.append(v)
        except KeyError:
            print(f"loading failed: {k}")
            continue
    return results


def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
    ds = LeaderboardDataStore(version, get_safe_name(version))
    ds.raw_data = load_raw_eval_results(file_path)
    print(f"raw data: {len(ds.raw_data)}")

    ds.qa_raw_df = get_leaderboard_df(ds, TaskType.qa, DEFAULT_METRIC_QA)
    print(f"QA data loaded: {ds.qa_raw_df.shape}")
    ds.qa_fmt_df = ds.qa_raw_df.copy()
    qa_cols, ds.qa_types = get_default_cols(TaskType.qa, ds.slug, add_fix_cols=True)
    # by default, drop the anonymous submissions
    ds.qa_fmt_df = ds.qa_fmt_df[~ds.qa_fmt_df[COL_NAME_IS_ANONYMOUS]][qa_cols]
    # reset the rank after dropping the anonymous submissions
    ds.qa_fmt_df = reset_rank(ds.qa_fmt_df)
    ds.qa_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)

    ds.doc_raw_df = get_leaderboard_df(ds, TaskType.long_doc, DEFAULT_METRIC_LONG_DOC)
    print(f"Long-Doc data loaded: {len(ds.doc_raw_df)}")
    ds.doc_fmt_df = ds.doc_raw_df.copy()
    doc_cols, ds.doc_types = get_default_cols(TaskType.long_doc, ds.slug, add_fix_cols=True)
    # by default, drop the anonymous submissions
    ds.doc_fmt_df = ds.doc_fmt_df[~ds.doc_fmt_df[COL_NAME_IS_ANONYMOUS]][doc_cols]
    # reset the rank after dropping the anonymous submissions
    ds.doc_fmt_df = reset_rank(ds.doc_fmt_df)
    ds.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)

    ds.reranking_models = sorted(list(frozenset([eval_result.reranking_model for eval_result in ds.raw_data])))
    return ds


def load_eval_results(file_path: Union[str, Path]) -> Dict[str, LeaderboardDataStore]:
    output = {}
    for version in BENCHMARK_VERSION_LIST:
        fn = f"{file_path}/{version}"
        output[version] = load_leaderboard_datastore(fn, version)
    return output
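For context, a minimal sketch of how this loader is typically consumed at app start-up (assuming EVAL_RESULTS_PATH is the constant defined in src/envs.py; the variable names below are illustrative only):

from src.envs import EVAL_RESULTS_PATH  # assumed to point at the local clone of the results dataset
from src.loaders import load_eval_results

# One LeaderboardDataStore per benchmark version, keyed by the version string.
datastores = load_eval_results(EVAL_RESULTS_PATH)

for version, ds in datastores.items():
    # qa_fmt_df / doc_fmt_df are the display-ready frames (anonymous rows dropped, ranks reset).
    print(version, ds.qa_fmt_df.shape, ds.doc_fmt_df.shape, len(ds.reranking_models))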
src/{read_evals.py → models.py}
RENAMED
@@ -1,38 +1,21 @@
 import json
-import os.path
 from collections import defaultdict
 from dataclasses import dataclass
+from enum import Enum
 from typing import List

 import pandas as pd

-from src.
+from src.columns import (
+    COL_NAME_IS_ANONYMOUS,
     COL_NAME_RERANKING_MODEL,
     COL_NAME_RERANKING_MODEL_LINK,
+    COL_NAME_RETRIEVAL_MODEL,
     COL_NAME_RETRIEVAL_MODEL_LINK,
     COL_NAME_REVISION,
     COL_NAME_TIMESTAMP,
-    COL_NAME_RETRIEVAL_MODEL,
-    COL_NAME_IS_ANONYMOUS,
-    COLS_QA,
-    QA_BENCHMARK_COLS,
-    COLS_LONG_DOC,
-    LONG_DOC_BENCHMARK_COLS,
-    COL_NAME_AVG,
-    COL_NAME_RANK
 )

-from src.display.formatting import make_clickable_model
-
-pd.options.mode.copy_on_write = True
-
-def calculate_mean(row):
-    if pd.isna(row).any():
-        return -1
-    else:
-        return row.mean()
-

 @dataclass
 class EvalResult:
@@ -40,6 +23,7 @@ class EvalResult:
     Evaluation result of a single embedding model with a specific reranking model on benchmarks over different
     domains, languages, and datasets
     """
+
     eval_name: str  # name of the evaluation, [retrieval_model]_[reranking_model]_[metric]
     retrieval_model: str
     reranking_model: str
@@ -56,6 +40,7 @@ class FullEvalResult:
     """
     Evaluation result of a single embedding model with a specific reranking model on benchmarks over different tasks
     """
+
     eval_name: str  # name of the evaluation, [retrieval_model]_[reranking_model]
     retrieval_model: str
     reranking_model: str
@@ -79,7 +64,6 @@ class FullEvalResult:
         result_list = []
         retrieval_model_link = ""
         reranking_model_link = ""
-        revision = ""
         for item in model_data:
             config = item.get("config", {})
             # eval results for different metrics
@@ -98,24 +82,26 @@ class FullEvalResult:
                 metric=config["metric"],
                 timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
                 revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e"),
-                is_anonymous=config.get("is_anonymous", False)
+                is_anonymous=config.get("is_anonymous", False),
             )
             result_list.append(eval_result)
+        eval_result = result_list[0]
         return cls(
-            eval_name=f"{
-            retrieval_model=
-            reranking_model=
+            eval_name=f"{eval_result.retrieval_model}_{eval_result.reranking_model}",
+            retrieval_model=eval_result.retrieval_model,
+            reranking_model=eval_result.reranking_model,
             retrieval_model_link=retrieval_model_link,
             reranking_model_link=reranking_model_link,
             results=result_list,
-            timestamp=
-            revision=
-            is_anonymous=
+            timestamp=eval_result.timestamp,
+            revision=eval_result.revision,
+            is_anonymous=eval_result.is_anonymous,
         )

-    def to_dict(self, task=
+    def to_dict(self, task="qa", metric="ndcg_at_3") -> List:
         """
         Convert the results in all the EvalResults over different tasks and metrics.
+        The output is a list of dict compatible with the dataframe UI
         """
         results = defaultdict(dict)
         for eval_result in self.results:
@@ -123,106 +109,66 @@ class FullEvalResult:
                 continue
             if eval_result.task != task:
                 continue
+            eval_name = eval_result.eval_name
+            results[eval_name]["eval_name"] = eval_name
+            results[eval_name][COL_NAME_RETRIEVAL_MODEL] = make_clickable_model(
+                self.retrieval_model, self.retrieval_model_link
+            )
+            results[eval_name][COL_NAME_RERANKING_MODEL] = make_clickable_model(
+                self.reranking_model, self.reranking_model_link
+            )
+            results[eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
+            results[eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
+            results[eval_name][COL_NAME_REVISION] = self.revision
+            results[eval_name][COL_NAME_TIMESTAMP] = self.timestamp
+            results[eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
+
             for result in eval_result.results:
                 # add result for each domain, language, and dataset
                 domain = result["domain"]
                 lang = result["lang"]
                 dataset = result["dataset"]
                 value = result["value"] * 100
-                if dataset ==
+                if dataset == "default":
                     benchmark_name = f"{domain}_{lang}"
                 else:
                     benchmark_name = f"{domain}_{lang}_{dataset}"
-                results[
+                results[eval_name][get_safe_name(benchmark_name)] = value
         return [v for v in results.values()]


-def get_leaderboard_df(raw_data: List[FullEvalResult], task: str, metric: str) -> pd.DataFrame:
-    """
-    Creates a dataframe from all the individual experiment results
-    """
-    cols = [COL_NAME_IS_ANONYMOUS, ]
-    if task == "qa":
-        cols += COLS_QA
-        benchmark_cols = QA_BENCHMARK_COLS
-    elif task == "long-doc":
-        cols += COLS_LONG_DOC
-        benchmark_cols = LONG_DOC_BENCHMARK_COLS
-    else:
-        raise NotImplemented
-    all_data_json = []
-    for v in raw_data:
-        all_data_json += v.to_dict(task=task, metric=metric)
-    df = pd.DataFrame.from_records(all_data_json)
-    # print(f'dataframe created: {df.shape}')
-
-    _benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
-
-    # calculate the average score for selected benchmarks
-    df[COL_NAME_AVG] = df[list(_benchmark_cols)].apply(calculate_mean, axis=1).round(decimals=2)
-    df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
-    df.reset_index(inplace=True, drop=True)
-
-    _cols = frozenset(cols).intersection(frozenset(df.columns.to_list()))
-    df = df[_cols].round(decimals=2)
-
-    # filter out if any of the benchmarks have not been produced
-    df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
-
-    # shorten the revision
-    df[COL_NAME_REVISION] = df[COL_NAME_REVISION].str[:6]
-
-    # # replace "0" with "-" for average score
-    # df[COL_NAME_AVG] = df[COL_NAME_AVG].replace(0, "-")
-    return df
+@dataclass
+class LeaderboardDataStore:
+    version: str
+    slug: str
+    raw_data: list = None
+    qa_raw_df: pd.DataFrame = pd.DataFrame()
+    doc_raw_df: pd.DataFrame = pd.DataFrame()
+    qa_fmt_df: pd.DataFrame = pd.DataFrame()
+    doc_fmt_df: pd.DataFrame = pd.DataFrame()
+    reranking_models: list = None
+    qa_types: list = None
+    doc_types: list = None
+
+
+# Define an enum class with the name `TaskType`. There are two types of tasks, `qa` and `long-doc`.
+class TaskType(Enum):
+    qa = "qa"
+    long_doc = "long-doc"
+
+
+def make_clickable_model(model_name: str, model_link: str):
+    # link = f"https://huggingface.co/{model_name}"
+    if not model_link or not model_link.startswith("https://"):
+        return model_name
+    return model_hyperlink(model_link, model_name)
+
+
+def model_hyperlink(link, model_name):
+    return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
+
+
+def get_safe_name(name: str):
+    """Get RFC 1123 compatible safe name"""
+    name = name.replace("-", "_")
+    return "".join(character.lower() for character in name if (character.isalnum() or character == "_"))
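A small usage sketch for the helpers defined above (illustrative only; the example model name and link are made up):

from src.models import get_safe_name, make_clickable_model

# get_safe_name lower-cases and keeps only letters, digits and "_", so it doubles as a slug/enum key.
assert get_safe_name("AIR-Bench_24.04") == "air_bench_2404"

# make_clickable_model falls back to the bare name when no https:// link is available.
print(make_clickable_model("bge-m3", ""))                                     # -> "bge-m3"
print(make_clickable_model("bge-m3", "https://huggingface.co/BAAI/bge-m3"))   # -> HTML anchor tag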
src/utils.py
CHANGED
@@ -1,24 +1,37 @@
-import json
 import hashlib
+import json
+import re
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import List

 import pandas as pd

-from src.benchmarks import
-from src.
+from src.benchmarks import LongDocBenchmarks, QABenchmarks
+from src.columns import (
+    COL_NAME_AVG,
+    COL_NAME_IS_ANONYMOUS,
+    COL_NAME_RANK,
+    COL_NAME_RERANKING_MODEL,
+    COL_NAME_RETRIEVAL_MODEL,
+    COL_NAME_REVISION,
+    COL_NAME_TIMESTAMP,
+    get_default_col_names_and_types,
+    get_fixed_col_names_and_types,
+)
+from src.envs import API, LATEST_BENCHMARK_VERSION, SEARCH_RESULTS_REPO
+from src.models import TaskType, get_safe_name
+
+
+def calculate_mean(row):
+    if pd.isna(row).any():
+        return -1
+    else:
+        return row.mean()


 def remove_html(input_str):
     # Regular expression for finding HTML tags
-    clean = re.sub(r
+    clean = re.sub(r"<.*?>", "", input_str)
     return clean

@@ -55,160 +68,152 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
     return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]


-def get_default_cols(task:
+def get_default_cols(task: TaskType, version_slug, add_fix_cols: bool = True) -> tuple:
     cols = []
     types = []
-    if task ==
-    elif task == "long-doc":
-        cols_list = COLS_LONG_DOC
-        types_list = TYPES_LONG_DOC
-        benchmark_list = BENCHMARK_COLS_LONG_DOC
+    if task == TaskType.qa:
+        benchmarks = QABenchmarks[version_slug]
+    elif task == TaskType.long_doc:
+        benchmarks = LongDocBenchmarks[version_slug]
     else:
-        raise
+        raise NotImplementedError
+    cols_list, types_list = get_default_col_names_and_types(benchmarks)
+    benchmark_list = [c.value.col_name for c in list(benchmarks.value)]
     for col_name, col_type in zip(cols_list, types_list):
         if col_name not in benchmark_list:
             continue
-        if len(columns) > 0 and col_name not in columns:
-            continue
         cols.append(col_name)
         types.append(col_type)
     if add_fix_cols:
         _cols = []
         _types = []
+        fixed_cols, fixed_cols_types = get_fixed_col_names_and_types()
         for col_name, col_type in zip(cols, types):
-            if col_name in
+            if col_name in fixed_cols:
                 continue
             _cols.append(col_name)
             _types.append(col_type)
-        cols =
-        types =
+        cols = fixed_cols + _cols
+        types = fixed_cols_types + _types
     return cols, types


-FIXED_COLS = [c.name for _, _, c in fixed_cols]
-FIXED_COLS_TYPES = [c.type for _, _, c in fixed_cols]
-
-
-def select_columns(
-    df: pd.DataFrame,
-    domain_query: list,
-    language_query: list,
-    task: str = "qa",
-    reset_ranking: bool = True
-) -> pd.DataFrame:
-    cols, _ = get_default_cols(task=task, columns=df.columns, add_fix_cols=False)
+def get_selected_cols(task, version_slug, domains, languages):
+    cols, _ = get_default_cols(task=task, version_slug=version_slug, add_fix_cols=False)
     selected_cols = []
     for c in cols:
+        if task == TaskType.qa:
+            eval_col = QABenchmarks[version_slug].value[c].value
+        elif task == TaskType.long_doc:
+            eval_col = LongDocBenchmarks[version_slug].value[c].value
+        else:
+            raise NotImplementedError
+        if eval_col.domain not in domains:
             continue
-        if eval_col.lang not in
+        if eval_col.lang not in languages:
             continue
         selected_cols.append(c)
     # We use COLS to maintain sorting
+    return selected_cols
+
+
+def select_columns(
+    df: pd.DataFrame,
+    domains: list,
+    languages: list,
+    task: TaskType = TaskType.qa,
+    reset_ranking: bool = True,
+    version_slug: str = None,
+) -> pd.DataFrame:
+    selected_cols = get_selected_cols(task, version_slug, domains, languages)
+    fixed_cols, _ = get_fixed_col_names_and_types()
+    filtered_df = df[fixed_cols + selected_cols]
+    filtered_df.replace({"": pd.NA}, inplace=True)
     if reset_ranking:
         filtered_df[COL_NAME_AVG] = filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
         filtered_df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
         filtered_df.reset_index(inplace=True, drop=True)
         filtered_df = reset_rank(filtered_df)
     return filtered_df


+def _update_df_elem(
+    task: TaskType,
+    version: str,
+    source_df: pd.DataFrame,
+    domains: list,
+    langs: list,
+    reranking_query: list,
+    query: str,
+    show_anonymous: bool,
+    reset_ranking: bool = True,
+    show_revision_and_timestamp: bool = False,
 ):
+    filtered_df = source_df.copy()
     if not show_anonymous:
         filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
     filtered_df = filter_models(filtered_df, reranking_query)
     filtered_df = filter_queries(query, filtered_df)
-    filtered_df = select_columns(filtered_df, domains, langs, task, reset_ranking)
+    filtered_df = select_columns(filtered_df, domains, langs, task, reset_ranking, get_safe_name(version))
     if not show_revision_and_timestamp:
         filtered_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
     return filtered_df


-    return _update_table(
-        "qa", hidden_df, domains, langs, reranking_query, query, show_anonymous, reset_ranking, show_revision_and_timestamp)
-
-
-def update_table_long_doc(
-    hidden_df: pd.DataFrame,
-    domains: list,
-    langs: list,
-    reranking_query: list,
-    query: str,
-    show_anonymous: bool,
-    show_revision_and_timestamp: bool = False,
-    reset_ranking: bool = True
-):
-    return
+def update_doc_df_elem(
+    version: str,
+    hidden_df: pd.DataFrame,
+    domains: list,
+    langs: list,
+    reranking_query: list,
+    query: str,
+    show_anonymous: bool,
+    show_revision_and_timestamp: bool = False,
+    reset_ranking: bool = True,
+):
+    return _update_df_elem(
+        TaskType.long_doc,
+        version,
+        hidden_df,
+        domains,
+        langs,
+        reranking_query,
+        query,
+        show_anonymous,
+        reset_ranking,
+        show_revision_and_timestamp,
+    )


 def update_metric(
+    datastore,
+    task: TaskType,
+    metric: str,
+    domains: list,
+    langs: list,
+    reranking_model: list,
+    query: str,
+    show_anonymous: bool = False,
+    show_revision_and_timestamp: bool = False,
 ) -> pd.DataFrame:
+    if task == TaskType.qa:
+        update_func = update_qa_df_elem
+    elif task == TaskType.long_doc:
+        update_func = update_doc_df_elem
+    else:
+        raise NotImplementedError
+    df_elem = get_leaderboard_df(datastore, task=task, metric=metric)
+    version = datastore.version
+    return update_func(
+        version,
+        df_elem,
+        domains,
+        langs,
+        reranking_model,
+        query,
+        show_anonymous,
+        show_revision_and_timestamp,
+    )


 def upload_file(filepath: str):
@@ -218,7 +223,6 @@ def upload_file(filepath: str):
     return filepath


-
 def get_iso_format_timestamp():
     # Get the current timestamp with UTC as the timezone
     current_timestamp = datetime.now(timezone.utc)
@@ -227,15 +231,15 @@ def get_iso_format_timestamp():
     current_timestamp = current_timestamp.replace(microsecond=0)

     # Convert to ISO 8601 format and replace the offset with 'Z'
-    iso_format_timestamp = current_timestamp.isoformat().replace(
-    filename_friendly_timestamp = current_timestamp.strftime(
+    iso_format_timestamp = current_timestamp.isoformat().replace("+00:00", "Z")
+    filename_friendly_timestamp = current_timestamp.strftime("%Y%m%d%H%M%S")
     return iso_format_timestamp, filename_friendly_timestamp


 def calculate_file_md5(file_path):
     md5 = hashlib.md5()

-    with open(file_path,
+    with open(file_path, "rb") as f:
         while True:
             data = f.read(4096)
             if not data:
@@ -246,13 +250,14 @@ def calculate_file_md5(file_path):


 def submit_results(
+    filepath: str,
+    model: str,
+    model_url: str,
+    reranking_model: str = "",
+    reranking_model_url: str = "",
+    version: str = LATEST_BENCHMARK_VERSION,
+    is_anonymous=False,
+):
     if not filepath.endswith(".zip"):
         return styled_error(f"file uploading aborted. wrong file type: {filepath}")

@@ -265,11 +270,13 @@ def submit_results(
     if not model_url.startswith("https://") and not model_url.startswith("http://"):
         # TODO: retrieve the model page and find the model name on the page
         return styled_error(
-            f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
+            f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
+        )
     if reranking_model != "NoReranker":
         if not reranking_model_url.startswith("https://") and not reranking_model_url.startswith("http://"):
             return styled_error(
-                f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
+                f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
+            )

     # rename the uploaded file
     input_fp = Path(filepath)
@@ -279,14 +286,15 @@ def submit_results(
     input_folder_path = input_fp.parent

     if not reranking_model:
-        reranking_model =
+        reranking_model = "NoReranker"
+
     API.upload_file(
         path_or_fileobj=filepath,
         path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
         repo_id=SEARCH_RESULTS_REPO,
         repo_type="dataset",
-        commit_message=f"feat: submit {model} to evaluate"
+        commit_message=f"feat: submit {model} to evaluate",
+    )

     output_config_fn = f"{output_fn.removesuffix('.zip')}.json"
     output_config = {
@@ -297,7 +305,7 @@
         "version": f"{version}",
         "is_anonymous": is_anonymous,
         "revision": f"{revision}",
-        "timestamp": f"{timestamp_config}"
+        "timestamp": f"{timestamp_config}",
     }
     with open(input_folder_path / output_config_fn, "w") as f:
         json.dump(output_config, f, indent=4, ensure_ascii=False)
@@ -306,7 +314,8 @@
         path_in_repo=f"{version}/{model}/{reranking_model}/{output_config_fn}",
         repo_id=SEARCH_RESULTS_REPO,
         repo_type="dataset",
-        commit_message=f"feat: submit {model} + {reranking_model} config"
+        commit_message=f"feat: submit {model} + {reranking_model} config",
+    )
     return styled_message(
         f"Thanks for submission!\n"
         f"Retrieval method: {model}\nReranking model: {reranking_model}\nSubmission revision: {revision}"
@@ -316,3 +325,125 @@
 def reset_rank(df):
     df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
     return df
+
+
+def get_leaderboard_df(datastore, task: TaskType, metric: str) -> pd.DataFrame:
+    """
+    Creates a dataframe from all the individual experiment results
+    """
+    # load the selected metrics into a DataFrame from the raw json
+    all_data_json = []
+    for v in datastore.raw_data:
+        all_data_json += v.to_dict(task=task.value, metric=metric)
+    df = pd.DataFrame.from_records(all_data_json)
+
+    # calculate the average scores for selected task
+    if task == TaskType.qa:
+        benchmarks = QABenchmarks[datastore.slug]
+    elif task == TaskType.long_doc:
+        benchmarks = LongDocBenchmarks[datastore.slug]
+    else:
+        raise NotImplementedError
+    valid_cols = frozenset(df.columns.to_list())
+    benchmark_cols = []
+    for t in list(benchmarks.value):
+        if t.value.col_name not in valid_cols:
+            continue
+        benchmark_cols.append(t.value.col_name)
+
+    # filter out the columns that are not in the data
+    df[COL_NAME_AVG] = df[list(benchmark_cols)].apply(calculate_mean, axis=1).round(decimals=2)
+    df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
+    df.reset_index(inplace=True, drop=True)

+    # filter out columns that are not in the data
+    display_cols = [COL_NAME_IS_ANONYMOUS, COL_NAME_AVG]
+    default_cols, _ = get_default_col_names_and_types(benchmarks)
+    for col in default_cols:
+        if col in valid_cols:
+            display_cols.append(col)
+    df = df[display_cols].round(decimals=2)
+
+    # rank the scores
+    df = reset_rank(df)
+
+    # shorten the revision
+    df[COL_NAME_REVISION] = df[COL_NAME_REVISION].str[:6]
+
+    return df
+
+
+def set_listeners(
+    task: TaskType,
+    target_df,
+    source_df,
+    search_bar,
+    version,
+    selected_domains,
+    selected_langs,
+    selected_rerankings,
+    show_anonymous,
+    show_revision_and_timestamp,
+):
+    if task == TaskType.qa:
+        update_table_func = update_qa_df_elem
+    elif task == TaskType.long_doc:
+        update_table_func = update_doc_df_elem
+    else:
+        raise NotImplementedError
+    selector_list = [selected_domains, selected_langs, selected_rerankings, search_bar, show_anonymous]
+    search_bar_args = [
+        source_df,
+        version,
+    ] + selector_list
+    selector_args = (
+        [version, source_df]
+        + selector_list
+        + [
+            show_revision_and_timestamp,
+        ]
+    )
+    # Set search_bar listener
+    search_bar.submit(update_table_func, search_bar_args, target_df)
+
+    # Set column-wise listener
+    for selector in selector_list:
+        selector.change(
+            update_table_func,
+            selector_args,
+            target_df,
+            queue=True,
+        )
+
+
+def update_qa_df_elem(
+    version: str,
+    hidden_df: pd.DataFrame,
+    domains: list,
+    langs: list,
+    reranking_query: list,
+    query: str,
+    show_anonymous: bool,
+    show_revision_and_timestamp: bool = False,
+    reset_ranking: bool = True,
+):
+    return _update_df_elem(
+        TaskType.qa,
+        version,
+        hidden_df,
+        domains,
+        langs,
+        reranking_query,
+        query,
+        show_anonymous,
+        reset_ranking,
+        show_revision_and_timestamp,
+    )
+
+
+def styled_error(error):
+    return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"
+
+
+def styled_message(message):
+    return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
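To see how these pieces compose, here is a hedged sketch of refreshing one leaderboard table when the user switches the metric (the `datastore` object would come from the loaders module and is not defined here; widget wiring is omitted):

from src.models import TaskType
from src.utils import update_metric

# `datastore` is a LeaderboardDataStore, e.g. load_eval_results(...)[version].
qa_df = update_metric(
    datastore,
    TaskType.qa,
    "ndcg_at_100",
    ["wiki", "news"],   # domains to keep
    ["en"],             # languages to keep
    [],                 # reranking models (empty list keeps all)
    "",                 # search query
    show_anonymous=False,
)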
tests/src/display/test_utils.py
DELETED
@@ -1,23 +0,0 @@
import pytest
from src.display.utils import fields, AutoEvalColumnQA, COLS_QA, COLS_LONG_DOC, COLS_LITE, TYPES_QA, TYPES_LONG_DOC, QA_BENCHMARK_COLS, LONG_DOC_BENCHMARK_COLS, get_default_auto_eval_column_dict


def test_fields():
    for c in fields(AutoEvalColumnQA):
        print(c)


def test_macro_variables():
    print(f'COLS_QA: {COLS_QA}')
    print(f'COLS_LONG_DOC: {COLS_LONG_DOC}')
    print(f'COLS_LITE: {COLS_LITE}')
    print(f'TYPES_QA: {TYPES_QA}')
    print(f'TYPES_LONG_DOC: {TYPES_LONG_DOC}')
    print(f'QA_BENCHMARK_COLS: {QA_BENCHMARK_COLS}')
    print(f'LONG_DOC_BENCHMARK_COLS: {LONG_DOC_BENCHMARK_COLS}')


def test_get_default_auto_eval_column_dict():
    auto_eval_column_dict_list = get_default_auto_eval_column_dict()
    assert len(auto_eval_column_dict_list) == 9
tests/src/test_benchmarks.py
CHANGED
@@ -1,9 +1,33 @@
+import pytest

+from src.benchmarks import LongDocBenchmarks, QABenchmarks
+from src.envs import BENCHMARK_VERSION_LIST

+# Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
+# 24.05
+# | Task | dev | test |
+# | ---- | --- | ---- |
+# | Long-Doc | 4 | 11 |
+# | QA | 54 | 53 |
+#
+# 24.04
+# | Task | test |
+# | ---- | ---- |
+# | Long-Doc | 15 |
+# | QA | 13 |


+@pytest.mark.parametrize("num_datasets_dict", [{"air_bench_2404": 13, "air_bench_2405": 53}])
+def test_qa_benchmarks(num_datasets_dict):
+    assert len(QABenchmarks) == len(BENCHMARK_VERSION_LIST)
+    for benchmark_list in list(QABenchmarks):
+        version_slug = benchmark_list.name
+        assert num_datasets_dict[version_slug] == len(benchmark_list.value)
+
+
+@pytest.mark.parametrize("num_datasets_dict", [{"air_bench_2404": 15, "air_bench_2405": 11}])
+def test_doc_benchmarks(num_datasets_dict):
+    assert len(LongDocBenchmarks) == len(BENCHMARK_VERSION_LIST)
+    for benchmark_list in list(LongDocBenchmarks):
+        version_slug = benchmark_list.name
+        assert num_datasets_dict[version_slug] == len(benchmark_list.value)
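The tests above rely on QABenchmarks and LongDocBenchmarks being enums keyed by the version slug; a brief sketch of that access pattern (a minimal illustration, assuming the attribute names shown in src/utils.py):

from src.benchmarks import QABenchmarks
from src.models import get_safe_name

slug = get_safe_name("AIR-Bench_24.04")  # -> "air_bench_2404"
version_benchmarks = QABenchmarks[slug]  # the enum member for that benchmark version
for b in list(version_benchmarks.value):
    # each entry carries a col_name plus the domain/lang metadata used for filtering
    print(b.value.col_name, b.value.domain, b.value.lang)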
tests/src/test_columns.py
ADDED
@@ -0,0 +1,119 @@
import pytest

from src.benchmarks import LongDocBenchmarks, QABenchmarks
from src.columns import (
    COL_NAME_AVG,
    COL_NAME_RANK,
    COL_NAME_RERANKING_MODEL,
    COL_NAME_RETRIEVAL_MODEL,
    COL_NAME_REVISION,
    COL_NAME_TIMESTAMP,
    get_default_auto_eval_column_dict,
    get_default_col_names_and_types,
    get_fixed_col_names_and_types,
    make_autoevalcolumn,
)

# Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
# 24.05
# | Task | dev | test |
# | ---- | --- | ---- |
# | Long-Doc | 4 | 11 |
# | QA | 54 | 53 |
#
# 24.04
# | Task | test |
# | ---- | ---- |
# | Long-Doc | 15 |
# | QA | 13 |


@pytest.fixture()
def expected_col_names():
    return [
        "rank",
        "retrieval_model",
        "reranking_model",
        "revision",
        "timestamp",
        "average",
        "retrieval_model_link",
        "reranking_model_link",
        "is_anonymous",
    ]


@pytest.fixture()
def expected_hidden_col_names():
    return [
        "retrieval_model_link",
        "reranking_model_link",
        "is_anonymous",
    ]


def test_get_default_auto_eval_column_dict(expected_col_names, expected_hidden_col_names):
    col_list = get_default_auto_eval_column_dict()
    assert len(col_list) == 9
    hidden_cols = []
    for col_tuple, expected_col in zip(col_list, expected_col_names):
        col, _, col_content = col_tuple
        assert col == expected_col
        if col_content.hidden:
            hidden_cols.append(col)
    assert hidden_cols == expected_hidden_col_names


def test_get_fixed_col_names_and_types():
    col_names, col_types = get_fixed_col_names_and_types()
    assert len(col_names) == 6
    assert len(col_types) == 6
    expected_col_and_type = [
        (COL_NAME_RANK, "number"),
        (COL_NAME_RETRIEVAL_MODEL, "markdown"),
        (COL_NAME_RERANKING_MODEL, "markdown"),
        (COL_NAME_REVISION, "markdown"),
        (COL_NAME_TIMESTAMP, "date"),
        (COL_NAME_AVG, "number"),
    ]
    for col_name, col_type, (c_name, c_type) in zip(col_names, col_types, expected_col_and_type):
        assert col_name == c_name
        assert col_type == c_type


@pytest.mark.parametrize(
    "benchmarks, expected_benchmark_len",
    [
        (QABenchmarks, {"air_bench_2404": 13, "air_bench_2405": 53}),
        (LongDocBenchmarks, {"air_bench_2404": 15, "air_bench_2405": 11}),
    ],
)
def test_make_autoevalcolumn(benchmarks, expected_benchmark_len, expected_col_names):
    expected_default_attrs = frozenset(expected_col_names)
    for benchmark in benchmarks:
        TestEvalColumn = make_autoevalcolumn("TestEvalColumn", benchmark)
        attrs = []
        for k, v in TestEvalColumn.__dict__.items():
            if not k.startswith("__"):
                attrs.append(k)
        attrs = frozenset(attrs)
        assert expected_default_attrs.issubset(attrs)
        benchmark_attrs = attrs.difference(expected_default_attrs)
        assert len(benchmark_attrs) == expected_benchmark_len[benchmark.name]


@pytest.mark.parametrize(
    "benchmarks, expected_benchmark_len",
    [
        (QABenchmarks, {"air_bench_2404": 13, "air_bench_2405": 53}),
        (LongDocBenchmarks, {"air_bench_2404": 15, "air_bench_2405": 11}),
    ],
)
def test_get_default_col_names_and_types(
    benchmarks, expected_benchmark_len, expected_col_names, expected_hidden_col_names
):
    default_col_len = len(expected_col_names)
    hidden_col_len = len(expected_hidden_col_names)
    for benchmark in benchmarks:
        col_names, col_types = get_default_col_names_and_types(benchmark)
        assert len(col_names) == expected_benchmark_len[benchmark.name] + default_col_len - hidden_col_len
tests/src/test_envs.py
ADDED
@@ -0,0 +1,14 @@
from air_benchmark.tasks import BenchmarkTable

from src.envs import BENCHMARK_VERSION_LIST, DEFAULT_METRIC_LONG_DOC, DEFAULT_METRIC_QA, METRIC_LIST


def test_benchmark_version_list():
    leaderboard_versions = frozenset(BENCHMARK_VERSION_LIST)
    available_versions = frozenset([k for k in BenchmarkTable.keys()])
    assert leaderboard_versions.issubset(available_versions)


def test_default_metrics():
    assert DEFAULT_METRIC_QA in METRIC_LIST
    assert DEFAULT_METRIC_LONG_DOC in METRIC_LIST
tests/src/test_loaders.py
ADDED
@@ -0,0 +1,46 @@
from pathlib import Path

import pandas as pd
import pytest

from src.loaders import load_eval_results, load_leaderboard_datastore, load_raw_eval_results

cur_fp = Path(__file__)


@pytest.mark.parametrize("version", ["AIR-Bench_24.04", "AIR-Bench_24.05"])
def test_load_raw_eval_results(version):
    raw_data = load_raw_eval_results(cur_fp.parents[1] / f"toydata/eval_results/{version}")
    assert len(raw_data) == 1
    full_eval_result = raw_data[0]
    expected_attr = [
        "eval_name",
        "retrieval_model",
        "reranking_model",
        "retrieval_model_link",
        "reranking_model_link",
        "results",
        "timestamp",
        "revision",
        "is_anonymous",
    ]
    result_attr = [k for k in full_eval_result.__dict__.keys() if k[:2] != "__" and k[-2:] != "__"]
    assert sorted(expected_attr) == sorted(result_attr)


@pytest.mark.parametrize("version", ["AIR-Bench_24.04", "AIR-Bench_24.05"])
def test_load_leaderboard_datastore(version):
    file_path = cur_fp.parents[1] / f"toydata/eval_results/{version}"
    datastore = load_leaderboard_datastore(file_path, version)
    for k, v in datastore.__dict__.items():
        if k[:2] != "__" and k[-2:] != "__":
            if isinstance(v, list):
                assert v
            elif isinstance(v, pd.DataFrame):
                assert not v.empty


def test_load_eval_results():
    file_path = cur_fp.parents[1] / "toydata/eval_results/"
    datastore_dict = load_eval_results(file_path)
    assert len(datastore_dict) == 2
tests/src/test_models.py
ADDED
@@ -0,0 +1,89 @@
from pathlib import Path

import pytest

from src.models import EvalResult, FullEvalResult

cur_fp = Path(__file__)


# Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
# 24.05
# | Task | dev | test |
# | ---- | --- | ---- |
# | Long-Doc | 4 | 11 |
# | QA | 54 | 53 |
#
# 24.04
# | Task | test |
# | ---- | ---- |
# | Long-Doc | 15 |
# | QA | 13 |
NUM_QA_BENCHMARKS_24_05 = 53
NUM_DOC_BENCHMARKS_24_05 = 11
NUM_QA_BENCHMARKS_24_04 = 13
NUM_DOC_BENCHMARKS_24_04 = 15


def test_eval_result():
    EvalResult(
        eval_name="eval_name",
        retrieval_model="bge-m3",
        reranking_model="NoReranking",
        results=[{"domain": "law", "lang": "en", "dataset": "lex_files_500K-600K", "value": 0.45723}],
        task="qa",
        metric="ndcg_at_3",
        timestamp="2024-05-14T03:09:08Z",
        revision="1e243f14bd295ccdea7a118fe847399d",
        is_anonymous=True,
    )


@pytest.mark.parametrize(
    "file_path",
    [
        "AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
        "AIR-Bench_24.05/bge-m3/NoReranker/results.json",
    ],
)
def test_full_eval_result_init_from_json_file(file_path):
    json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
    full_eval_result = FullEvalResult.init_from_json_file(json_fp)
    assert json_fp.parents[0].stem == full_eval_result.reranking_model
    assert json_fp.parents[1].stem == full_eval_result.retrieval_model
    assert len(full_eval_result.results) == 70


@pytest.mark.parametrize(
    "file_path, task, expected_num_results",
    [
        ("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "qa", NUM_QA_BENCHMARKS_24_04),
        (
            "AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
            "long-doc",
            NUM_DOC_BENCHMARKS_24_04,
        ),
        ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "qa", NUM_QA_BENCHMARKS_24_05),
        ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_05),
    ],
)
def test_full_eval_result_to_dict(file_path, task, expected_num_results):
    json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
    full_eval_result = FullEvalResult.init_from_json_file(json_fp)
    result_dict_list = full_eval_result.to_dict(task)
    assert len(result_dict_list) == 1
    result = result_dict_list[0]
    attr_list = frozenset(
        [
            "eval_name",
            "Retrieval Method",
            "Reranking Model",
            "Retrieval Model LINK",
            "Reranking Model LINK",
            "Revision",
            "Submission Date",
            "Anonymous Submission",
        ]
    )
    result_cols = list(result.keys())
    assert len(result_cols) == (expected_num_results + len(attr_list))
tests/src/test_read_evals.py
DELETED
@@ -1,68 +0,0 @@
from pathlib import Path

from src.read_evals import FullEvalResult, get_raw_eval_results, get_leaderboard_df

cur_fp = Path(__file__)


def test_init_from_json_file():
    json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
    full_eval_result = FullEvalResult.init_from_json_file(json_fp)
    num_different_task_domain_lang_metric_dataset_combination = 6
    assert len(full_eval_result.results) == \
        num_different_task_domain_lang_metric_dataset_combination
    assert full_eval_result.retrieval_model == "bge-m3"
    assert full_eval_result.reranking_model == "bge-reranker-v2-m3"


def test_to_dict():
    json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
    full_eval_result = FullEvalResult.init_from_json_file(json_fp)
    result_list = full_eval_result.to_dict(task='qa', metric='ndcg_at_1')
    assert len(result_list) == 1
    result_dict = result_list[0]
    assert result_dict["Retrieval Model"] == "bge-m3"
    assert result_dict["Reranking Model"] == "bge-reranker-v2-m3"
    assert result_dict["wiki_en"] is not None
    assert result_dict["wiki_zh"] is not None


def test_get_raw_eval_results():
    results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
    results = get_raw_eval_results(results_path)
    # only load the latest results
    assert len(results) == 4
    assert results[0].eval_name == "bge-base-en-v1.5_NoReranker"
    assert len(results[0].results) == 70
    assert results[0].eval_name == "bge-base-en-v1.5_bge-reranker-v2-m3"
    assert len(results[1].results) == 70


def test_get_leaderboard_df():
    results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
    raw_data = get_raw_eval_results(results_path)
    df = get_leaderboard_df(raw_data, 'qa', 'ndcg_at_10')
    assert df.shape[0] == 4
    # the results contain only one embedding model
    # for i in range(4):
    #     assert df["Retrieval Model"][i] == "bge-m3"
    # # the results contain only two reranking model
    # assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
    # assert df["Reranking Model"][1] == "NoReranker"
    # assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
    # assert not df[['Average ⬆️', 'wiki_en', 'wiki_zh', ]].isnull().values.any()


def test_get_leaderboard_df_long_doc():
    results_path = cur_fp.parents[2] / "toydata" / "test_results"
    raw_data = get_raw_eval_results(results_path)
    df = get_leaderboard_df(raw_data, 'long-doc', 'ndcg_at_1')
    assert df.shape[0] == 2
    # the results contain only one embedding model
    for i in range(2):
        assert df["Retrieval Model"][i] == "bge-m3"
    # the results contains only two reranking model
    assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
    assert df["Reranking Model"][1] == "NoReranker"
    assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
    assert not df[['Average ⬆️', 'law_en_lex_files_500k_600k', ]].isnull().values.any()
tests/src/test_utils.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from pathlib import Path
+
+import pandas as pd
+import pytest
+
+from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
+from src.models import TaskType, model_hyperlink
+from src.utils import (
+    _update_df_elem,
+    calculate_mean,
+    filter_models,
+    filter_queries,
+    get_default_cols,
+    get_leaderboard_df,
+    get_selected_cols,
+    remove_html,
+    select_columns,
+)
+
+cur_fp = Path(__file__)
+
+NUM_QA_BENCHMARKS_24_05 = 53
+NUM_DOC_BENCHMARKS_24_05 = 11
+NUM_QA_BENCHMARKS_24_04 = 13
+NUM_DOC_BENCHMARKS_24_04 = 15
+
+
+@pytest.fixture
+def toy_df():
+    return pd.DataFrame(
+        {
+            "Retrieval Method": ["bge-m3", "bge-m3", "jina-embeddings-v2-base", "jina-embeddings-v2-base"],
+            "Reranking Model": ["bge-reranker-v2-m3", "NoReranker", "bge-reranker-v2-m3", "NoReranker"],
+            "Rank 🏆": [1, 2, 3, 4],
+            "Revision": ["123", "234", "345", "456"],
+            "Submission Date": ["", "", "", ""],
+            "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
+            "wiki_en": [0.8, 0.7, 0.2, 0.1],
+            "wiki_zh": [0.4, 0.1, 0.4, 0.3],
+            "news_en": [0.8, 0.7, 0.2, 0.1],
+            "news_zh": [0.4, 0.1, 0.2, 0.3],
+            "Anonymous Submission": [False, False, False, True],
+        }
+    )
+
+
+def test_remove_html():
+    model_name = "jina-embeddings-v3"
+    html_str = model_hyperlink("https://jina.ai", model_name)
+    output_str = remove_html(html_str)
+    assert output_str == model_name
+
+
+def test_calculate_mean():
+    valid_row = [1, 3]
+    invalid_row = [2, pd.NA]
+    df = pd.DataFrame([valid_row, invalid_row], columns=["a", "b"])
+    result = list(df.apply(calculate_mean, axis=1))
+    assert result[0] == sum(valid_row) / 2
+    assert result[1] == -1
+
+
+@pytest.mark.parametrize(
+    "models, expected",
+    [
+        (["model1", "model3"], 2),
+        (["model1", "model_missing"], 1),
+        (["model1", "model2", "model3"], 3),
+        (
+            [
+                "model1",
+            ],
+            1,
+        ),
+        ([], 3),
+    ],
+)
+def test_filter_models(models, expected):
+    df = pd.DataFrame(
+        {
+            COL_NAME_RERANKING_MODEL: [
+                "model1",
+                "model2",
+                "model3",
+            ],
+            "col2": [1, 2, 3],
+        }
+    )
+    output_df = filter_models(df, models)
+    assert len(output_df) == expected
+
+
+@pytest.mark.parametrize(
+    "query, expected",
+    [
+        ("model1;model3", 2),
+        ("model1;model4", 1),
+        ("model1;model2;model3", 3),
+        ("model1", 1),
+        ("", 3),
+    ],
+)
+def test_filter_queries(query, expected):
+    df = pd.DataFrame(
+        {
+            COL_NAME_RETRIEVAL_MODEL: [
+                "model1",
+                "model2",
+                "model3",
+            ],
+            COL_NAME_RERANKING_MODEL: [
+                "model4",
+                "model5",
+                "model6",
+            ],
+        }
+    )
+    output_df = filter_queries(query, df)
+    assert len(output_df) == expected
+
+
+@pytest.mark.parametrize(
+    "task_type, slug, add_fix_cols, expected",
+    [
+        (TaskType.qa, "air_bench_2404", True, NUM_QA_BENCHMARKS_24_04),
+        (TaskType.long_doc, "air_bench_2404", True, NUM_DOC_BENCHMARKS_24_04),
+        (TaskType.qa, "air_bench_2405", False, NUM_QA_BENCHMARKS_24_05),
+        (TaskType.long_doc, "air_bench_2405", False, NUM_DOC_BENCHMARKS_24_05),
+    ],
+)
+def test_get_default_cols(task_type, slug, add_fix_cols, expected):
+    attr_cols = ["Rank 🏆", "Retrieval Method", "Reranking Model", "Revision", "Submission Date", "Average ⬆️"]
+    cols, types = get_default_cols(task_type, slug)
+    cols_set = frozenset(cols)
+    attrs_set = frozenset(attr_cols)
+    if add_fix_cols:
+        assert attrs_set.issubset(cols_set)
+    benchmark_cols = list(cols_set.difference(attrs_set))
+    assert len(benchmark_cols) == expected
+
+
+@pytest.mark.parametrize(
+    "task_type, domains, languages, expected",
+    [
+        (
+            TaskType.qa,
+            ["wiki", "news"],
+            [
+                "zh",
+            ],
+            ["wiki_zh", "news_zh"],
+        ),
+        (
+            TaskType.qa,
+            [
+                "law",
+            ],
+            ["zh", "en"],
+            ["law_en"],
+        ),
+        (
+            TaskType.long_doc,
+            ["healthcare"],
+            ["zh", "en"],
+            [
+                "healthcare_en_pubmed_100k_200k_1",
+                "healthcare_en_pubmed_100k_200k_2",
+                "healthcare_en_pubmed_100k_200k_3",
+                "healthcare_en_pubmed_40k_50k_5_merged",
+                "healthcare_en_pubmed_30k_40k_10_merged",
+            ],
+        ),
+    ],
+)
+def test_get_selected_cols(task_type, domains, languages, expected):
+    slug = "air_bench_2404"
+    cols = get_selected_cols(task_type, slug, domains, languages)
+    assert sorted(cols) == sorted(expected)
+
+
+@pytest.mark.parametrize("reset_rank", [False])
+def test_select_columns(toy_df, reset_rank):
+    expected = [
+        "Rank 🏆",
+        "Retrieval Method",
+        "Reranking Model",
+        "Revision",
+        "Submission Date",
+        "Average ⬆️",
+        "news_zh",
+    ]
+    df_result = select_columns(toy_df, ["news"], ["zh"], version_slug="air_bench_2404", reset_ranking=reset_rank)
+    assert len(df_result.columns) == len(expected)
+    if reset_rank:
+        assert df_result["Average ⬆️"].equals(df_result["news_zh"])
+    else:
+        assert df_result["Average ⬆️"].equals(toy_df["Average ⬆️"])
+
+
+@pytest.mark.parametrize(
+    "reset_rank, show_anony",
+    [
+        (False, True),
+        (True, True),
+        (True, False),
+    ],
+)
+def test__update_df_elem(toy_df, reset_rank, show_anony):
+    df = _update_df_elem(TaskType.qa, "AIR-Bench_24.04", toy_df, ["news"], ["zh"], [], "", show_anony, reset_rank)
+    if show_anony:
+        assert df.shape[0] == 4
+    else:
+        assert df.shape[0] == 3
+    if show_anony:
+        if reset_rank:
+            assert df["Average ⬆️"].equals(df["news_zh"])
+        else:
+            assert df["Average ⬆️"].equals(toy_df["Average ⬆️"])
+
+
+@pytest.mark.parametrize(
+    "version, task_type",
+    [
+        ("AIR-Bench_24.04", TaskType.qa),
+        ("AIR-Bench_24.04", TaskType.long_doc),
+        ("AIR-Bench_24.05", TaskType.qa),
+        ("AIR-Bench_24.05", TaskType.long_doc),
+    ],
+)
+def test_get_leaderboard_df(version, task_type):
+    from src.loaders import load_raw_eval_results
+    from src.models import LeaderboardDataStore, get_safe_name
+
+    raw_data = load_raw_eval_results(cur_fp.parents[1] / f"toydata/eval_results/{version}")
+    ds = LeaderboardDataStore(version, get_safe_name(version), raw_data=raw_data)
+    df = get_leaderboard_df(ds, task_type, "ndcg_at_10")
+    assert df.shape[0] == 1
tests/test_utils.py
DELETED
@@ -1,115 +0,0 @@
-import pandas as pd
-import pytest
-
-from src.utils import filter_models, search_table, filter_queries, select_columns, update_table_long_doc, get_iso_format_timestamp, get_default_cols, update_table
-from src.display.utils import COL_NAME_IS_ANONYMOUS, COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL, COL_NAME_RANK, COL_NAME_AVG
-
-
-@pytest.fixture
-def toy_df():
-    return pd.DataFrame(
-        {
-            "Retrieval Model": [
-                "bge-m3",
-                "bge-m3",
-                "jina-embeddings-v2-base",
-                "jina-embeddings-v2-base"
-            ],
-            "Reranking Model": [
-                "bge-reranker-v2-m3",
-                "NoReranker",
-                "bge-reranker-v2-m3",
-                "NoReranker"
-            ],
-            "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
-            "wiki_en": [0.8, 0.7, 0.2, 0.1],
-            "wiki_zh": [0.4, 0.1, 0.4, 0.3],
-            "news_en": [0.8, 0.7, 0.2, 0.1],
-            "news_zh": [0.4, 0.1, 0.4, 0.3],
-        }
-    )
-
-
-@pytest.fixture
-def toy_df_long_doc():
-    return pd.DataFrame(
-        {
-            "Retrieval Model": [
-                "bge-m3",
-                "bge-m3",
-                "jina-embeddings-v2-base",
-                "jina-embeddings-v2-base"
-            ],
-            "Reranking Model": [
-                "bge-reranker-v2-m3",
-                "NoReranker",
-                "bge-reranker-v2-m3",
-                "NoReranker"
-            ],
-            "Average ⬆️": [0.6, 0.4, 0.3, 0.2],
-            "law_en_lex_files_300k_400k": [0.4, 0.1, 0.4, 0.3],
-            "law_en_lex_files_400k_500k": [0.8, 0.7, 0.2, 0.1],
-            "law_en_lex_files_500k_600k": [0.8, 0.7, 0.2, 0.1],
-            "law_en_lex_files_600k_700k": [0.4, 0.1, 0.4, 0.3],
-        }
-    )
-def test_filter_models(toy_df):
-    df_result = filter_models(toy_df, ["bge-reranker-v2-m3", ])
-    assert len(df_result) == 2
-    assert df_result.iloc[0]["Reranking Model"] == "bge-reranker-v2-m3"
-
-
-def test_search_table(toy_df):
-    df_result = search_table(toy_df, "jina")
-    assert len(df_result) == 2
-    assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
-
-
-def test_filter_queries(toy_df):
-    df_result = filter_queries("jina", toy_df)
-    assert len(df_result) == 2
-    assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
-
-
-def test_select_columns(toy_df):
-    df_result = select_columns(toy_df, ['news',], ['zh',])
-    assert len(df_result.columns) == 4
-    assert df_result['Average ⬆️'].equals(df_result['news_zh'])
-
-
-def test_update_table_long_doc(toy_df_long_doc):
-    df_result = update_table_long_doc(toy_df_long_doc, ['law',], ['en',], ["bge-reranker-v2-m3", ], "jina")
-    print(df_result)
-
-
-def test_get_iso_format_timestamp():
-    timestamp_config, timestamp_fn = get_iso_format_timestamp()
-    assert len(timestamp_fn) == 14
-    assert len(timestamp_config) == 20
-    assert timestamp_config[-1] == "Z"
-
-
-def test_get_default_cols():
-    cols, types = get_default_cols("qa")
-    for c, t in zip(cols, types):
-        print(f"type({c}): {t}")
-    assert len(frozenset(cols)) == len(cols)
-
-
-def test_update_table():
-    df = pd.DataFrame(
-        {
-            COL_NAME_IS_ANONYMOUS: [False, False, False],
-            COL_NAME_REVISION: ["a1", "a2", "a3"],
-            COL_NAME_TIMESTAMP: ["2024-05-12T12:24:02Z"] * 3,
-            COL_NAME_RERANKING_MODEL: ["NoReranker"] * 3,
-            COL_NAME_RETRIEVAL_MODEL: ["Foo"] * 3,
-            COL_NAME_RANK: [1, 2, 3],
-            COL_NAME_AVG: [0.1, 0.2, 0.3], # unsorted values
-            "wiki_en": [0.1, 0.2, 0.3]
-        }
-    )
-    results = update_table(df, "wiki", "en", ["NoReranker"], "", show_anonymous=False, reset_ranking=False, show_revision_and_timestamp=False)
-    # keep the RANK as the same regardless of the unsorted averages
-    assert results[COL_NAME_RANK].to_list() == [1, 2, 3]
-
tests/toydata/eval_results/AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json
ADDED
The diff for this file is too large to render. See the raw diff.
tests/toydata/eval_results/AIR-Bench_24.05/bge-m3/NoReranker/results.json
ADDED
The diff for this file is too large to render. See the raw diff.
tests/toydata/test_data.json
DELETED
@@ -1,98 +0,0 @@
-[
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "long_doc",
-            "metric": "ndcg_at_1"
-        },
-        "results": [
-            {
-                "domain": "law",
-                "lang": "en",
-                "dataset": "lex_files_500K-600K",
-                "value": 0.75723
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "long_doc",
-            "metric": "ndcg_at_3"
-        },
-        "results": [
-            {
-                "domain": "law",
-                "lang": "en",
-                "dataset": "lex_files_500K-600K",
-                "value": 0.69909
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "qa",
-            "metric": "ndcg_at_1"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "en",
-                "dataset": "unknown",
-                "value": 0.69083
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "qa",
-            "metric": "ndcg_at_3"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "en",
-                "dataset": "unknown",
-                "value": 0.73359
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "qa",
-            "metric": "ndcg_at_1"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "zh",
-                "dataset": "unknown",
-                "value": 0.78358
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "qa",
-            "metric": "ndcg_at_3"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "zh",
-                "dataset": "unknown",
-                "value": 0.78358
-            }
-        ]
-    }
-]
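
Note: test_data.json above and the two results files removed below share the same flat schema: a JSON array whose entries pair a "config" block (retrieval_model, reranking_model, task, metric) with a "results" list of per-domain/lang/dataset scores. The following is a small stand-alone reader for that legacy layout, written here only to document the schema being retired; it is not the project's loader, which now reads the per-version files added under tests/toydata/eval_results/:

import json
from collections import defaultdict

def read_legacy_results(path, task, metric):
    # Collect scores keyed by (domain, lang, dataset) for one task/metric pair.
    with open(path, encoding="utf-8") as f:
        entries = json.load(f)
    scores = defaultdict(list)
    for entry in entries:
        cfg = entry["config"]
        if cfg["task"] == task and cfg["metric"] == metric:
            for res in entry["results"]:
                scores[(res["domain"], res["lang"], res["dataset"])].append(res["value"])
    return dict(scores)

# Example: read_legacy_results("tests/toydata/test_data.json", "qa", "ndcg_at_1")
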
tests/toydata/test_results/bge-m3/NoReranker/results_2023-11-21T18-10-08.json
DELETED
@@ -1,98 +0,0 @@
-[
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "NoReranker",
-            "task": "long_doc",
-            "metric": "ndcg_at_1"
-        },
-        "results": [
-            {
-                "domain": "law",
-                "lang": "en",
-                "dataset": "lex_files_500K-600K",
-                "value": 0.45723
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "NoReranker",
-            "task": "long_doc",
-            "metric": "ndcg_at_3"
-        },
-        "results": [
-            {
-                "domain": "law",
-                "lang": "en",
-                "dataset": "lex_files_500K-600K",
-                "value": 0.49909
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "NoReranker",
-            "task": "qa",
-            "metric": "ndcg_at_1"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "en",
-                "dataset": "unknown",
-                "value": 0.49083
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "NoReranker",
-            "task": "qa",
-            "metric": "ndcg_at_3"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "en",
-                "dataset": "unknown",
-                "value": 0.43359
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "NoReranker",
-            "task": "qa",
-            "metric": "ndcg_at_1"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "zh",
-                "dataset": "unknown",
-                "value": 0.78358
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "NoReranker",
-            "task": "qa",
-            "metric": "ndcg_at_3"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "zh",
-                "dataset": "unknown",
-                "value": 0.78358
-            }
-        ]
-    }
-]
tests/toydata/test_results/bge-m3/bge-reranker-v2-m3/results_2023-11-21T18-10-08.json
DELETED
@@ -1,98 +0,0 @@
-[
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "long_doc",
-            "metric": "ndcg_at_1"
-        },
-        "results": [
-            {
-                "domain": "law",
-                "lang": "en",
-                "dataset": "lex_files_500K-600K",
-                "value": 0.75723
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "long_doc",
-            "metric": "ndcg_at_3"
-        },
-        "results": [
-            {
-                "domain": "law",
-                "lang": "en",
-                "dataset": "lex_files_500K-600K",
-                "value": 0.69909
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "qa",
-            "metric": "ndcg_at_1"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "en",
-                "dataset": "unknown",
-                "value": 0.69083
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "qa",
-            "metric": "ndcg_at_3"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "en",
-                "dataset": "unknown",
-                "value": 0.73359
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "qa",
-            "metric": "ndcg_at_1"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "zh",
-                "dataset": "unknown",
-                "value": 0.78358
-            }
-        ]
-    },
-    {
-        "config": {
-            "retrieval_model": "bge-m3",
-            "reranking_model": "bge-reranker-v2-m3",
-            "task": "qa",
-            "metric": "ndcg_at_3"
-        },
-        "results": [
-            {
-                "domain": "wiki",
-                "lang": "zh",
-                "dataset": "unknown",
-                "value": 0.78358
-            }
-        ]
-    }
-]