Spaces: AIR-Bench

Commit 1a22df4 (1 parent: a3fa5e4), committed by nan

refactor: remove the legacy imports

app.py CHANGED
@@ -87,7 +87,6 @@ def update_metric_long_doc(


 def update_datastore(version):
-    print("triggered update_datastore")
     global datastore
     global data
     datastore = data[version]
@@ -104,7 +103,6 @@ def update_datastore(version):
 def update_datastore_long_doc(version):
     global datastore
     global data
-    print("triggered update_datastore_long_doc")
     datastore = data[version]
     selected_domains = get_domain_dropdown(LongDocBenchmarks[datastore.slug])
     selected_langs = get_language_dropdown(LongDocBenchmarks[datastore.slug])
@@ -336,12 +334,11 @@ with demo:
             show_anonymous = get_anonymous_checkbox()
             with gr.Row():
                 show_revision_and_timestamp = get_revision_and_ts_checkbox()
-            with gr.Tabs(elem_classes="tab-buttons") as sub_tabs:
+            with gr.Tabs(elem_classes="tab-buttons"):
                 with gr.TabItem("Retrieval + Reranking", id=20):
                     with gr.Row():
                         with gr.Column():
                             search_bar = get_search_bar()
-                        # select reranking model
                         with gr.Column():
                             selected_rerankings = get_reranking_dropdown(datastore.reranking_models)

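For context: update_datastore and update_datastore_long_doc exist so that switching the version dropdown swaps the module-level datastore that every other callback reads, which is why they mutate globals. A minimal sketch of that wiring (the component names and the callback's outputs here are illustrative assumptions, not app.py's actual ones):

import gradio as gr

data = {}         # version -> LeaderboardDataStore, filled once at startup
datastore = None  # the currently selected version's data, shared by callbacks

def update_datastore(version):
    # Swap the shared handle; the real callback also returns refreshed
    # dropdowns and tables so the UI re-renders for the new version.
    global datastore
    datastore = data[version]

with gr.Blocks() as demo:
    version_dropdown = gr.Dropdown(choices=list(data), label="version")
    version_dropdown.change(update_datastore, inputs=version_dropdown)

Dropping "as sub_tabs" is safe because the bound name was evidently never used; gr.Tabs still registers its children when entered as a plain "with" block.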
src/benchmarks.py CHANGED
@@ -3,7 +3,7 @@ from enum import Enum

 from air_benchmark.tasks.tasks import BenchmarkTable

-from src.envs import METRIC_LIST
+from src.envs import METRIC_LIST, BENCHMARK_VERSION_LIST


 def get_safe_name(name: str):
@@ -59,19 +59,16 @@ def get_benchmarks_enum(benchmark_version, task_type):
     return benchmark_dict


-versions = ("AIR-Bench_24.04", "AIR-Bench_24.05")
 qa_benchmark_dict = {}
-for version in versions:
+for version in BENCHMARK_VERSION_LIST:
     safe_version_name = get_safe_name(version)[-4:]
     qa_benchmark_dict[safe_version_name] = Enum(f"QABenchmarks_{safe_version_name}", get_benchmarks_enum(version, "qa"))

 long_doc_benchmark_dict = {}
-for version in versions:
+for version in BENCHMARK_VERSION_LIST:
     safe_version_name = get_safe_name(version)[-4:]
     long_doc_benchmark_dict[safe_version_name] = Enum(f"LongDocBenchmarks_{safe_version_name}", get_benchmarks_enum(version, "long-doc"))

-# _qa_benchmark_dict, = get_benchmarks_enum('AIR-Bench_24.04', "qa")
-# _long_doc_benchmark_dict = get_benchmarks_enum('AIR-Bench_24.04', "long-doc")

 QABenchmarks = Enum('QABenchmarks', qa_benchmark_dict)
 LongDocBenchmarks = Enum('LongDocBenchmarks', long_doc_benchmark_dict)
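A worked example of the slug and lookup logic above: get_safe_name replaces '-' with '_' and drops any character that is not alphanumeric or an underscore, so "AIR-Bench_24.04" loses its dot and its last four characters form the slug "2404". QABenchmarks is then an Enum of Enums keyed by slug; a minimal sketch with a stand-in member (real members come from BenchmarkTable):

from enum import Enum

qa_benchmark_dict = {
    "2404": Enum("QABenchmarks_2404", {"wiki_en": "placeholder"}),
}
QABenchmarks = Enum("QABenchmarks", qa_benchmark_dict)

per_version = QABenchmarks["2404"].value  # the inner QABenchmarks_2404 Enum
print([m.name for m in per_version])      # -> ['wiki_en']

Moving the version list into src.envs means benchmarks.py and loaders.py can no longer disagree about which versions exist; both now iterate the same BENCHMARK_VERSION_LIST.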
src/display/columns.py CHANGED
@@ -1,6 +1,5 @@
 from dataclasses import dataclass, make_dataclass

-from src.benchmarks import QABenchmarks, LongDocBenchmarks
 from src.envs import COL_NAME_AVG, COL_NAME_RETRIEVAL_MODEL, COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL_LINK, \
     COL_NAME_RERANKING_MODEL_LINK, COL_NAME_RANK, COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_IS_ANONYMOUS

@@ -76,22 +75,7 @@ def get_default_col_names_and_types(benchmarks):
     col_types = [c.type for c in fields(AutoEvalColumn) if not c.hidden]
     return col_names, col_types

-# AutoEvalColumnQA = make_autoevalcolumn("AutoEvalColumnQA", QABenchmarks)
-# COLS_QA = [c.name for c in fields(AutoEvalColumnQA) if not c.hidden]
-# TYPES_QA = [c.type for c in fields(AutoEvalColumnQA) if not c.hidden]
-

 def get_fixed_col_names_and_types():
     fixed_cols = get_default_auto_eval_column_dict()[:-3]
     return [c.name for _, _, c in fixed_cols], [c.type for _, _, c in fixed_cols]
-
-# fixed_cols = get_default_auto_eval_column_dict()[:-3]
-# FIXED_COLS = [c.name for _, _, c in fixed_cols]
-# FIXED_COLS_TYPES = [c.type for _, _, c in fixed_cols]
-
-
-# AutoEvalColumnLongDoc = make_autoevalcolumn("AutoEvalColumnLongDoc", LongDocBenchmarks)
-# COLS_LONG_DOC = [c.name for c in fields(AutoEvalColumnLongDoc) if not c.hidden]
-# TYPES_LONG_DOC = [c.type for c in fields(AutoEvalColumnLongDoc) if not c.hidden]
-
-# Column selection
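The two surviving helpers both walk lists of (attribute_name, type, column) triples that make_dataclass assembles into an AutoEvalColumn class. A minimal sketch of that pattern under assumed names (ColumnContent, the fields helper, and the example columns below are assumptions; the real definitions live in this file outside the diff):

from dataclasses import dataclass, make_dataclass

@dataclass(frozen=True)  # frozen keeps instances hashable, so they can be defaults
class ColumnContent:
    name: str             # header shown in the leaderboard UI
    type: str             # gradio dataframe dtype, e.g. "number" or "markdown"
    hidden: bool = False

def fields(raw_class):
    # Assumed helper: pull the ColumnContent defaults back off the class,
    # rather than dataclasses.fields(), so c.name/c.type/c.hidden resolve.
    return [v for k, v in raw_class.__dict__.items()
            if not k.startswith("__") and not k.endswith("__")]

column_dict = [
    ("rank", ColumnContent, ColumnContent("Rank", "number")),
    ("retrieval_model", ColumnContent, ColumnContent("Retrieval Model", "markdown")),
]
AutoEvalColumn = make_dataclass("AutoEvalColumn", column_dict)

col_names = [c.name for c in fields(AutoEvalColumn) if not c.hidden]
col_types = [c.type for c in fields(AutoEvalColumn) if not c.hidden]
print(col_names, col_types)  # -> ['Rank', 'Retrieval Model'] ['number', 'markdown']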
src/display/components.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
+
 from src.envs import BENCHMARK_VERSION_LIST, LATEST_BENCHMARK_VERSION
-from src.benchmarks import QABenchmarks
+

 def get_version_dropdown():
     return gr.Dropdown(
src/envs.py CHANGED
@@ -27,7 +27,7 @@ BM25_LINK = model_hyperlink("https://github.com/castorini/pyserini", "BM25")

 BENCHMARK_VERSION_LIST = [
     "AIR-Bench_24.04",
-    # "AIR-Bench_24.05",
+    "AIR-Bench_24.05",
 ]

 LATEST_BENCHMARK_VERSION = BENCHMARK_VERSION_LIST[0]
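Re-listing "AIR-Bench_24.05" is the functional half of this commit: because benchmarks.py and loaders.py now iterate BENCHMARK_VERSION_LIST, adding the entry here is all it takes to build the 24.05 enums and load its leaderboard data. Note the default stays put, directly from the code shown:

BENCHMARK_VERSION_LIST = ["AIR-Bench_24.04", "AIR-Bench_24.05"]
LATEST_BENCHMARK_VERSION = BENCHMARK_VERSION_LIST[0]  # still "AIR-Bench_24.04"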
src/loaders.py CHANGED
@@ -5,7 +5,6 @@ import pandas as pd

 from src.envs import DEFAULT_METRIC_QA, DEFAULT_METRIC_LONG_DOC, COL_NAME_REVISION, COL_NAME_TIMESTAMP, \
     COL_NAME_IS_ANONYMOUS, BENCHMARK_VERSION_LIST
-
 from src.models import FullEvalResult, LeaderboardDataStore
 from src.utils import get_default_cols, get_leaderboard_df

@@ -50,6 +49,7 @@ def load_raw_eval_results(results_path: str) -> List[FullEvalResult]:
             continue
     return results

+
 def get_safe_name(name: str):
     """Get RFC 1123 compatible safe name"""
     name = name.replace('-', '_')
@@ -58,6 +58,7 @@ def get_safe_name(name: str):
                    for character in name
                    if (character.isalnum() or character == '_'))

+
 def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
     slug = get_safe_name(version)[-4:]
     lb_data_store = LeaderboardDataStore(version, slug, None, None, None, None, None, None, None, None)
@@ -69,8 +70,6 @@ def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
     print(f'QA data loaded: {lb_data_store.raw_df_qa.shape}')
     lb_data_store.leaderboard_df_qa = lb_data_store.raw_df_qa.copy()
     shown_columns_qa, types_qa = get_default_cols('qa', lb_data_store.slug, add_fix_cols=True)
-    # shown_columns_qa, types_qa = get_default_cols(
-    #     'qa', lb_data_store.leaderboard_df_qa.columns, add_fix_cols=True)
     lb_data_store.types_qa = types_qa
     lb_data_store.leaderboard_df_qa = \
         lb_data_store.leaderboard_df_qa[~lb_data_store.leaderboard_df_qa[COL_NAME_IS_ANONYMOUS]][shown_columns_qa]
@@ -95,7 +94,6 @@ def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:

 def load_eval_results(file_path: str):
     output = {}
-    # versions = BENCHMARK_VERSION_LIST
     for version in BENCHMARK_VERSION_LIST:
         fn = f"{file_path}/{version}"
         output[version] = load_leaderboard_datastore(fn, version)
src/models.py CHANGED
@@ -6,9 +6,9 @@ from typing import List, Optional

 import pandas as pd

 from src.benchmarks import get_safe_name
+from src.display.formatting import make_clickable_model
 from src.envs import COL_NAME_RETRIEVAL_MODEL, COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL_LINK, \
     COL_NAME_RERANKING_MODEL_LINK, COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_IS_ANONYMOUS
-from src.display.formatting import make_clickable_model


 @dataclass
@@ -92,7 +92,8 @@ class FullEvalResult:

     def to_dict(self, task='qa', metric='ndcg_at_3') -> List:
         """
-        Convert the results in all the EvalResults over different tasks and metrics. The output is a list of dict compatible with the dataframe UI
+        Convert the results in all the EvalResults over different tasks and metrics.
+        The output is a list of dict compatible with the dataframe UI
         """
         results = defaultdict(dict)
         for eval_result in self.results:
@@ -111,7 +112,6 @@ class FullEvalResult:
             results[eval_result.eval_name][COL_NAME_TIMESTAMP] = self.timestamp
             results[eval_result.eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous

-            # print(f'result loaded: {eval_result.eval_name}')
             for result in eval_result.results:
                 # add result for each domain, language, and dataset
                 domain = result["domain"]
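The reflowed docstring is the contract the rest of the app relies on: to_dict flattens every EvalResult into one row per eval_name, keyed by the COL_NAME_* display columns plus one column per benchmark dataset. A rough sketch of the accumulation shape (the key strings and the eval_name format are placeholders, not the real COL_NAME_* values):

from collections import defaultdict

results = defaultdict(dict)                           # eval_name -> row dict
results["model__reranker"]["Revision"] = "abc1234"    # COL_NAME_REVISION
results["model__reranker"]["Anonymous"] = False       # COL_NAME_IS_ANONYMOUS
results["model__reranker"]["wiki_en"] = 0.671         # one benchmark score
rows = list(results.values())                         # list of dicts for the dataframe UI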
src/utils.py CHANGED
@@ -1,18 +1,17 @@
-import json
 import hashlib
+import json
+import re
 from datetime import datetime, timezone
 from pathlib import Path

 import pandas as pd

 from src.benchmarks import QABenchmarks, LongDocBenchmarks
-from src.display.formatting import styled_message, styled_error
 from src.display.columns import get_default_col_names_and_types, get_fixed_col_names_and_types
+from src.display.formatting import styled_message, styled_error
 from src.envs import API, SEARCH_RESULTS_REPO, LATEST_BENCHMARK_VERSION, COL_NAME_AVG, COL_NAME_RETRIEVAL_MODEL, \
     COL_NAME_RERANKING_MODEL, COL_NAME_RANK, COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_IS_ANONYMOUS

-import re
-

 def calculate_mean(row):
     if pd.isna(row).any():
@@ -20,6 +19,7 @@ def calculate_mean(row):
     else:
         return row.mean()

+
 def remove_html(input_str):
     # Regular expression for finding HTML tags
     clean = re.sub(r'<.*?>', '', input_str)
@@ -59,7 +59,7 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
     return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]


-def get_default_cols(task: str, version_slug, add_fix_cols: bool=True) -> tuple:
+def get_default_cols(task: str, version_slug, add_fix_cols: bool = True) -> tuple:
     cols = []
     types = []
     if task == "qa":
@@ -105,6 +105,8 @@ def select_columns(
             eval_col = QABenchmarks[version_slug].value[c].value
         elif task == "long-doc":
             eval_col = LongDocBenchmarks[version_slug].value[c].value
+        else:
+            raise NotImplemented
         if eval_col.domain not in domain_query:
             continue
         if eval_col.lang not in language_query:
@@ -122,6 +124,7 @@ def select_columns(

     return filtered_df

+
 def get_safe_name(name: str):
     """Get RFC 1123 compatible safe name"""
     name = name.replace('-', '_')
@@ -130,6 +133,7 @@ def get_safe_name(name: str):
                    for character in name
                    if (character.isalnum() or character == '_'))

+
 def _update_table(
     task: str,
     version: str,
@@ -249,9 +253,9 @@ def submit_results(
         filepath: str,
         model: str,
         model_url: str,
-        reranking_model: str="",
-        reranking_model_url: str="",
-        version: str=LATEST_BENCHMARK_VERSION,
+        reranking_model: str = "",
+        reranking_model_url: str = "",
+        version: str = LATEST_BENCHMARK_VERSION,
         is_anonymous=False):
     if not filepath.endswith(".zip"):
         return styled_error(f"file uploading aborted. wrong file type: {filepath}")
@@ -280,7 +284,7 @@ def submit_results(

     if not reranking_model:
         reranking_model = 'NoReranker'
-
+
     API.upload_file(
         path_or_fileobj=filepath,
         path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
@@ -384,14 +388,15 @@ def set_listeners(
         search_bar,
         show_anonymous
     ]
-    search_bar_args = [source_df, version,] + selector_list
-    selector_args = [version, source_df] + selector_list + [show_revision_and_timestamp,]
+    search_bar_args = [source_df, version, ] + selector_list
+    selector_args = [version, source_df] + selector_list + [show_revision_and_timestamp, ]
     # Set search_bar listener
     search_bar.submit(update_table_func, search_bar_args, target_df)

     # Set column-wise listener
     for selector in selector_list:
-        selector.change(update_table_func, selector_args, target_df, queue=True,)
+        selector.change(update_table_func, selector_args, target_df, queue=True, )
+

 def update_table(
     version: str,
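Two details worth noting in this file. First, set_listeners fans one update function out to both the search bar's submit event and every selector's change event, so any control refreshes the same target dataframe. A minimal sketch of that pattern (the signature and argument handling are illustrative; the real function builds the arg lists shown in the diff):

def set_listeners(update_table_func, search_bar, selector_list, target_df, args):
    # One callback, many triggers: submit on the search bar plus change on
    # every selector all re-render the same gradio dataframe component.
    search_bar.submit(update_table_func, args, target_df)
    for selector in selector_list:
        selector.change(update_table_func, args, target_df, queue=True)

Second, the new guard in select_columns uses "raise NotImplemented", which in Python 3 itself fails with a TypeError, since NotImplemented is a constant rather than an exception class. A sketch of the dispatch it guards, spelled with the exception class (helper name assumed):

from src.benchmarks import QABenchmarks, LongDocBenchmarks

def lookup_benchmark(task: str, version_slug: str, name: str):
    # Resolve a column name to its per-version benchmark entry.
    if task == "qa":
        return QABenchmarks[version_slug].value[name].value
    elif task == "long-doc":
        return LongDocBenchmarks[version_slug].value[name].value
    else:
        raise NotImplementedError(f"unknown task: {task}")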