Spaces:
AIR-Bench
/
Running on CPU Upgrade

nan committed on
Commit
3fcf957
1 Parent(s): 5e03e4a

refactor: refactor the benchmarks

Browse files
Files changed (4) hide show
  1. app.py +6 -6
  2. src/benchmarks.py +27 -18
  3. src/utils.py +3 -3
  4. tests/src/test_benchmarks.py +2 -0
app.py CHANGED
@@ -6,8 +6,8 @@ from src.about import (
6
  TITLE
7
  )
8
  from src.benchmarks import (
9
- qa_benchmark_dict,
10
- long_doc_benchmark_dict
11
  )
12
  from src.display.css_html_js import custom_css
13
  from src.envs import (
@@ -76,11 +76,11 @@ def update_metric_long_doc(
76
  return update_metric(data["AIR-Bench_24.04"].raw_data, "long-doc", metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp)
77
 
78
 
79
- DOMAIN_COLS_QA = list(frozenset([c.domain for c in qa_benchmark_dict.values()]))
80
- LANG_COLS_QA = list(frozenset([c.lang for c in qa_benchmark_dict.values()]))
81
 
82
- DOMAIN_COLS_LONG_DOC = list(frozenset([c.domain for c in long_doc_benchmark_dict.values()]))
83
- LANG_COLS_LONG_DOC = list(frozenset([c.lang for c in long_doc_benchmark_dict.values()]))
84
 
85
  demo = gr.Blocks(css=custom_css)
86
 
 
6
  TITLE
7
  )
8
  from src.benchmarks import (
9
+ BenchmarksQA,
10
+ BenchmarksLongDoc
11
  )
12
  from src.display.css_html_js import custom_css
13
  from src.envs import (
 
76
  return update_metric(data["AIR-Bench_24.04"].raw_data, "long-doc", metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp)
77
 
78
 
79
+ DOMAIN_COLS_QA = list(frozenset([c.value.domain for c in list(BenchmarksQA)]))
80
+ LANG_COLS_QA = list(frozenset([c.value.lang for c in list(BenchmarksQA)]))
81
 
82
+ DOMAIN_COLS_LONG_DOC = list(frozenset([c.value.domain for c in list(BenchmarksLongDoc)]))
83
+ LANG_COLS_LONG_DOC = list(frozenset([c.value.lang for c in list(BenchmarksLongDoc)]))
84
 
85
  demo = gr.Blocks(css=custom_css)
86
 
src/benchmarks.py CHANGED
@@ -25,25 +25,34 @@ class Benchmark:
25
  task: str
26
 
27
 
28
- qa_benchmark_dict = {}
29
- long_doc_benchmark_dict = {}
30
- for task, domain_dict in BenchmarkTable['AIR-Bench_24.04'].items():
31
- for domain, lang_dict in domain_dict.items():
32
- for lang, dataset_list in lang_dict.items():
33
- if task == "qa":
34
- benchmark_name = f"{domain}_{lang}"
35
- benchmark_name = get_safe_name(benchmark_name)
36
- col_name = benchmark_name
37
- for metric in dataset_list:
38
- qa_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
39
- elif task == "long-doc":
40
- for dataset in dataset_list:
41
- benchmark_name = f"{domain}_{lang}_{dataset}"
42
  benchmark_name = get_safe_name(benchmark_name)
43
  col_name = benchmark_name
44
- for metric in METRIC_LIST:
45
- long_doc_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain,
 
 
 
 
 
 
 
 
 
 
 
46
  lang, task)
 
47
 
48
- BenchmarksQA = Enum('BenchmarksQA', qa_benchmark_dict)
49
- BenchmarksLongDoc = Enum('BenchmarksLongDoc', long_doc_benchmark_dict)
 
 
 
25
  task: str
26
 
27
 
28
+ # create a function return an enum class containing all the benchmarks
29
+ def get_benchmarks_enum(benchmark_version):
30
+ qa_benchmark_dict = {}
31
+ long_doc_benchmark_dict = {}
32
+ for task, domain_dict in BenchmarkTable[benchmark_version].items():
33
+ for domain, lang_dict in domain_dict.items():
34
+ for lang, dataset_list in lang_dict.items():
35
+ if task == "qa":
36
+ benchmark_name = f"{domain}_{lang}"
 
 
 
 
 
37
  benchmark_name = get_safe_name(benchmark_name)
38
  col_name = benchmark_name
39
+ for metric in dataset_list:
40
+ qa_benchmark_dict[benchmark_name] = \
41
+ Benchmark(
42
+ benchmark_name, metric, col_name, domain, lang, task)
43
+ elif task == "long-doc":
44
+ for dataset in dataset_list:
45
+ benchmark_name = f"{domain}_{lang}_{dataset}"
46
+ benchmark_name = get_safe_name(benchmark_name)
47
+ col_name = benchmark_name
48
+ for metric in METRIC_LIST:
49
+ long_doc_benchmark_dict[benchmark_name] = \
50
+ Benchmark(
51
+ benchmark_name, metric, col_name, domain,
52
  lang, task)
53
+ return qa_benchmark_dict, long_doc_benchmark_dict
54
 
55
+ _qa_benchmark_dict, _long_doc_benchmark_dict = get_benchmarks_enum('AIR-Bench_24.04')
56
+
57
+ BenchmarksQA = Enum('BenchmarksQA', _qa_benchmark_dict)
58
+ BenchmarksLongDoc = Enum('BenchmarksLongDoc', _long_doc_benchmark_dict)
src/utils.py CHANGED
@@ -6,7 +6,7 @@ from typing import List
6
 
7
  import pandas as pd
8
 
9
- from src.benchmarks import qa_benchmark_dict, long_doc_benchmark_dict, BenchmarksQA, BenchmarksLongDoc
10
  from src.display.formatting import styled_message, styled_error
11
  from src.display.columns import COL_NAME_AVG, COL_NAME_RETRIEVAL_MODEL, COL_NAME_RERANKING_MODEL, COL_NAME_RANK, \
12
  COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_IS_ANONYMOUS, COLS_QA, TYPES_QA, COLS_LONG_DOC, TYPES_LONG_DOC, \
@@ -68,11 +68,11 @@ def get_default_cols(task: str, columns: list=[], add_fix_cols: bool=True) -> li
68
  if task == "qa":
69
  cols_list = COLS_QA
70
  types_list = TYPES_QA
71
- benchmark_list = [c.col_name for c in qa_benchmark_dict.values()]
72
  elif task == "long-doc":
73
  cols_list = COLS_LONG_DOC
74
  types_list = TYPES_LONG_DOC
75
- benchmark_list = [c.col_name for c in long_doc_benchmark_dict.values()]
76
  else:
77
  raise NotImplemented
78
  for col_name, col_type in zip(cols_list, types_list):
 
6
 
7
  import pandas as pd
8
 
9
+ from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
10
  from src.display.formatting import styled_message, styled_error
11
  from src.display.columns import COL_NAME_AVG, COL_NAME_RETRIEVAL_MODEL, COL_NAME_RERANKING_MODEL, COL_NAME_RANK, \
12
  COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_IS_ANONYMOUS, COLS_QA, TYPES_QA, COLS_LONG_DOC, TYPES_LONG_DOC, \
 
68
  if task == "qa":
69
  cols_list = COLS_QA
70
  types_list = TYPES_QA
71
+ benchmark_list = [c.value.col_name for c in list(BenchmarksQA)]
72
  elif task == "long-doc":
73
  cols_list = COLS_LONG_DOC
74
  types_list = TYPES_LONG_DOC
75
+ benchmark_list = [c.value.col_name for c in list(BenchmarksLongDoc)]
76
  else:
77
  raise NotImplemented
78
  for col_name, col_type in zip(cols_list, types_list):
tests/src/test_benchmarks.py CHANGED
@@ -3,6 +3,8 @@ from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
3
 
4
  def test_qabenchmarks():
5
  print(list(BenchmarksQA))
 
 
6
 
7
 
8
  def test_longdocbenchmarks():
 
3
 
4
  def test_qabenchmarks():
5
  print(list(BenchmarksQA))
6
+ for benchmark in list(BenchmarksQA):
7
+ print(benchmark.name, benchmark.metric, benchmark.col_name, benchmark.domain, benchmark.lang, benchmark.task)
8
 
9
 
10
  def test_longdocbenchmarks():