refactor: refactor the benchmarks
Files changed:
- app.py (+6 -6)
- src/benchmarks.py (+27 -18)
- src/utils.py (+3 -3)
- tests/src/test_benchmarks.py (+2 -0)
app.py
CHANGED
@@ -6,8 +6,8 @@ from src.about import (
     TITLE
 )
 from src.benchmarks import (
-
-
+    BenchmarksQA,
+    BenchmarksLongDoc
 )
 from src.display.css_html_js import custom_css
 from src.envs import (
@@ -76,11 +76,11 @@ def update_metric_long_doc(
     return update_metric(data["AIR-Bench_24.04"].raw_data, "long-doc", metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp)
 
 
-DOMAIN_COLS_QA = list(frozenset([c.domain for c in
-LANG_COLS_QA = list(frozenset([c.lang for c in
+DOMAIN_COLS_QA = list(frozenset([c.value.domain for c in list(BenchmarksQA)]))
+LANG_COLS_QA = list(frozenset([c.value.lang for c in list(BenchmarksQA)]))
 
-DOMAIN_COLS_LONG_DOC = list(frozenset([c.domain for c in
-LANG_COLS_LONG_DOC = list(frozenset([c.lang for c in
+DOMAIN_COLS_LONG_DOC = list(frozenset([c.value.domain for c in list(BenchmarksLongDoc)]))
+LANG_COLS_LONG_DOC = list(frozenset([c.value.lang for c in list(BenchmarksLongDoc)]))
 
 demo = gr.Blocks(css=custom_css)
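Note on the pattern: DOMAIN_COLS_QA and friends are now derived by iterating the enum members and reading fields off each member's .value, i.e. the underlying Benchmark dataclass. A minimal self-contained sketch of that access pattern (field order follows the Benchmark(...) call in the diff, the first field's name is assumed, and the member data here is invented for illustration):

from dataclasses import dataclass
from enum import Enum

@dataclass
class Benchmark:
    name: str
    metric: str
    col_name: str
    domain: str
    lang: str
    task: str

# Two illustrative members; the real ones are generated from BenchmarkTable.
BenchmarksQA = Enum("BenchmarksQA", {
    "wiki_en": Benchmark("wiki_en", "ndcg_at_10", "wiki_en", "wiki", "en", "qa"),
    "law_en": Benchmark("law_en", "ndcg_at_10", "law_en", "law", "en", "qa"),
})

# Fields live on member.value, hence c.value.domain rather than c.domain.
DOMAIN_COLS_QA = list(frozenset([c.value.domain for c in list(BenchmarksQA)]))
LANG_COLS_QA = list(frozenset([c.value.lang for c in list(BenchmarksQA)]))
print(DOMAIN_COLS_QA, LANG_COLS_QA)  # e.g. ['wiki', 'law'] ['en']

One consequence worth knowing: frozenset deduplicates but keeps no order, so these column lists come out in arbitrary order; wrapping the comprehension in sorted() instead of list() would make them deterministic.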
src/benchmarks.py
CHANGED
@@ -25,25 +25,34 @@ class Benchmark:
     task: str
 
 
-
-
-
-
-
-
-
-
-
-                for metric in dataset_list:
-                    qa_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
-            elif task == "long-doc":
-                for dataset in dataset_list:
-                    benchmark_name = f"{domain}_{lang}_{dataset}"
+# create a function return an enum class containing all the benchmarks
+def get_benchmarks_enum(benchmark_version):
+    qa_benchmark_dict = {}
+    long_doc_benchmark_dict = {}
+    for task, domain_dict in BenchmarkTable[benchmark_version].items():
+        for domain, lang_dict in domain_dict.items():
+            for lang, dataset_list in lang_dict.items():
+                if task == "qa":
+                    benchmark_name = f"{domain}_{lang}"
                     benchmark_name = get_safe_name(benchmark_name)
                     col_name = benchmark_name
-                    for metric in
-
+                    for metric in dataset_list:
+                        qa_benchmark_dict[benchmark_name] = \
+                            Benchmark(
+                                benchmark_name, metric, col_name, domain, lang, task)
+                elif task == "long-doc":
+                    for dataset in dataset_list:
+                        benchmark_name = f"{domain}_{lang}_{dataset}"
+                        benchmark_name = get_safe_name(benchmark_name)
+                        col_name = benchmark_name
+                        for metric in METRIC_LIST:
+                            long_doc_benchmark_dict[benchmark_name] = \
+                                Benchmark(
+                                    benchmark_name, metric, col_name, domain,
                                     lang, task)
+    return qa_benchmark_dict, long_doc_benchmark_dict
 
-
-
+_qa_benchmark_dict, _long_doc_benchmark_dict = get_benchmarks_enum('AIR-Bench_24.04')
+
+BenchmarksQA = Enum('BenchmarksQA', _qa_benchmark_dict)
+BenchmarksLongDoc = Enum('BenchmarksLongDoc', _long_doc_benchmark_dict)
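For readers unfamiliar with the functional Enum API used here: Enum('BenchmarksQA', mapping) turns each key of the mapping into a member name and each value into that member's .value. A compact sketch of the construction, using an invented stand-in for BenchmarkTable and a simplified get_safe_name (names and data are illustrative, not from the repository):

from dataclasses import dataclass
from enum import Enum

@dataclass
class Benchmark:
    name: str
    metric: str
    col_name: str
    domain: str
    lang: str
    task: str

def get_safe_name(name: str) -> str:
    # stand-in for the repo helper: sanitize into a valid identifier
    return name.replace("-", "_").replace(" ", "_")

table = {"wiki": {"en": ["ndcg_at_10"]}}  # pretend BenchmarkTable[version]["qa"]

qa_dict = {}
for domain, lang_dict in table.items():
    for lang, dataset_list in lang_dict.items():
        name = get_safe_name(f"{domain}_{lang}")
        for metric in dataset_list:
            qa_dict[name] = Benchmark(name, metric, name, domain, lang, "qa")

BenchmarksQA = Enum("BenchmarksQA", qa_dict)  # dict keys -> member names
print(BenchmarksQA.wiki_en.value.metric)      # 'ndcg_at_10'

Note that the QA dict is keyed by domain and language only, so when dataset_list contains several metrics each assignment overwrites the previous one and only the last metric survives; the long-doc branch, keyed by domain, language and dataset while looping over METRIC_LIST, has the same last-writer-wins behavior.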
src/utils.py
CHANGED
@@ -6,7 +6,7 @@ from typing import List
 
 import pandas as pd
 
-from src.benchmarks import
+from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
 from src.display.formatting import styled_message, styled_error
 from src.display.columns import COL_NAME_AVG, COL_NAME_RETRIEVAL_MODEL, COL_NAME_RERANKING_MODEL, COL_NAME_RANK, \
     COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_IS_ANONYMOUS, COLS_QA, TYPES_QA, COLS_LONG_DOC, TYPES_LONG_DOC, \
@@ -68,11 +68,11 @@ def get_default_cols(task: str, columns: list=[], add_fix_cols: bool=True) -> li
     if task == "qa":
         cols_list = COLS_QA
         types_list = TYPES_QA
-        benchmark_list = [c.col_name for c in
+        benchmark_list = [c.value.col_name for c in list(BenchmarksQA)]
     elif task == "long-doc":
         cols_list = COLS_LONG_DOC
         types_list = TYPES_LONG_DOC
-        benchmark_list = [c.col_name for c in
+        benchmark_list = [c.value.col_name for c in list(BenchmarksLongDoc)]
     else:
         raise NotImplemented
     for col_name, col_type in zip(cols_list, types_list):
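The .value hop in benchmark_list is load-bearing: enum members do not forward attribute lookups to the object they wrap, so c.col_name on a member raises AttributeError while c.value.col_name reads the Benchmark field. A quick throwaway illustration (names invented):

from dataclasses import dataclass
from enum import Enum

@dataclass
class Benchmark:
    col_name: str

Demo = Enum("Demo", {"wiki_en": Benchmark("wiki_en")})

print(Demo.wiki_en.value.col_name)  # 'wiki_en'
try:
    Demo.wiki_en.col_name           # members don't proxy their value's attributes
except AttributeError as exc:
    print(exc)                      # no attribute 'col_name'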
tests/src/test_benchmarks.py
CHANGED
@@ -3,6 +3,8 @@ from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
 
 def test_qabenchmarks():
     print(list(BenchmarksQA))
+    for benchmark in list(BenchmarksQA):
+        print(benchmark.name, benchmark.metric, benchmark.col_name, benchmark.domain, benchmark.lang, benchmark.task)
 
 
 def test_longdocbenchmarks():
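One caveat on the added loop: given the plain Enum('BenchmarksQA', _qa_benchmark_dict) construction in src/benchmarks.py, fields such as metric, col_name, domain, lang and task live on benchmark.value rather than on the member itself, so benchmark.metric would raise AttributeError (benchmark.name, by contrast, is the Enum member name and works). A version of the loop consistent with that construction would be:

def test_qabenchmarks():
    print(list(BenchmarksQA))
    for benchmark in BenchmarksQA:
        b = benchmark.value  # the Benchmark dataclass behind the member
        print(benchmark.name, b.metric, b.col_name, b.domain, b.lang, b.task)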