fix(MT-BENCH): Added fix for:
Browse files
1. Model type is now fixed at "chat" for MT-BENCH; pretrained models are neither shown nor selectable.
2. Language selection in the MT-BENCH tab is limited to EN, DE, ES, FR, IT.
app.py
CHANGED
@@ -101,28 +101,28 @@ with demo:
|
|
101 |
|
102 |
demo.load(
|
103 |
core.update_task_groups_and_fewshot,
|
104 |
-
[gr.State(value=0), fewshot],
|
105 |
-
[shown_tasks, fewshot, selected_tab],
|
106 |
)
|
107 |
fewshot.change(
|
108 |
core.update_task_groups_and_fewshot,
|
109 |
-
[selected_tab, fewshot],
|
110 |
-
[shown_tasks, fewshot, selected_tab],
|
111 |
)
|
112 |
acc.select(
|
113 |
core.update_task_groups_and_fewshot,
|
114 |
-
inputs=[gr.State(value=0), fewshot],
|
115 |
-
outputs=[shown_tasks, fewshot, selected_tab],
|
116 |
)
|
117 |
misc.select(
|
118 |
core.update_task_groups_and_fewshot,
|
119 |
-
inputs=[gr.State(value=1), fewshot],
|
120 |
-
outputs=[shown_tasks, fewshot, selected_tab],
|
121 |
)
|
122 |
mtbench.select(
|
123 |
core.update_task_groups_and_fewshot,
|
124 |
-
inputs=[gr.State(value=2), fewshot],
|
125 |
-
outputs=[shown_tasks, fewshot, selected_tab],
|
126 |
)
|
127 |
for comp, fn in [
|
128 |
(search_bar, "submit"),
|
|
|
101 |
|
102 |
demo.load(
|
103 |
core.update_task_groups_and_fewshot,
|
104 |
+
[gr.State(value=0), model_types, langs_bar,fewshot],
|
105 |
+
[shown_tasks, fewshot, selected_tab, model_types, langs_bar],
|
106 |
)
|
107 |
fewshot.change(
|
108 |
core.update_task_groups_and_fewshot,
|
109 |
+
[selected_tab, model_types, langs_bar, fewshot],
|
110 |
+
[shown_tasks, fewshot, selected_tab, model_types, langs_bar],
|
111 |
)
|
112 |
acc.select(
|
113 |
core.update_task_groups_and_fewshot,
|
114 |
+
inputs=[gr.State(value=0), model_types, langs_bar, fewshot],
|
115 |
+
outputs=[shown_tasks, fewshot, selected_tab, model_types, langs_bar],
|
116 |
)
|
117 |
misc.select(
|
118 |
core.update_task_groups_and_fewshot,
|
119 |
+
inputs=[gr.State(value=1), model_types, langs_bar, fewshot],
|
120 |
+
outputs=[shown_tasks, fewshot, selected_tab, model_types, langs_bar],
|
121 |
)
|
122 |
mtbench.select(
|
123 |
core.update_task_groups_and_fewshot,
|
124 |
+
inputs=[gr.State(value=2), model_types, langs_bar, fewshot],
|
125 |
+
outputs=[shown_tasks, fewshot, selected_tab, model_types, langs_bar],
|
126 |
)
|
127 |
for comp, fn in [
|
128 |
(search_bar, "submit"),
|
core.py
CHANGED
@@ -4,17 +4,17 @@ import os
|
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
import pandas as pd
|
7 |
-
import plotly.express as px
|
8 |
from datasets import load_dataset
|
9 |
|
10 |
import style
|
|
|
11 |
|
12 |
ZERO_SHOT_ONLY = ["BELEBELE", "MT-Bench"]
|
13 |
FEW_SHOT_ONLY = ["GSM8K", "TruthfulQA"]
|
14 |
|
15 |
|
16 |
def init():
|
17 |
-
global repo_id, config_name, split_name, hidden_df, task_group_names_list, task_group_type_dict, task_groups_shots_dict, languages_list, model_type_dict
|
18 |
|
19 |
repo_id = os.getenv("OGX_LEADERBOARD_DATASET_NAME")
|
20 |
config_name = os.getenv("OGX_LEADERBOARD_DATASET_CONFIG")
|
@@ -29,6 +29,7 @@ def init():
|
|
29 |
task_groups_shots_df = hidden_df[hidden_df["Few_Shot"] == True][["Task_Group", "Number_Shots"]].drop_duplicates()
|
30 |
task_groups_shots_dict = task_groups_shots_df.set_index("Task_Group")["Number_Shots"].to_dict()
|
31 |
languages_list = hidden_df["Language"].drop_duplicates().str.upper().tolist()
|
|
|
32 |
model_type_df = hidden_df[["Model_Name", "Model_Type"]].drop_duplicates()
|
33 |
model_type_dict = model_type_df.set_index("Model_Name")["Model_Type"].to_dict()
|
34 |
|
@@ -127,7 +128,7 @@ def update_df(
|
|
127 |
return sort_cols(df, fewshot)
|
128 |
|
129 |
|
130 |
-
def update_task_groups_and_fewshot(current_selected_tab: int, is_fewshot_current: bool = False):
|
131 |
selected_task_type = get_selected_task_type(current_selected_tab)
|
132 |
available_tasks = get_available_task_groups(selected_task_type, is_fewshot_current)
|
133 |
new_selected_tasks = available_tasks.copy()
|
@@ -154,7 +155,36 @@ def update_task_groups_and_fewshot(current_selected_tab: int, is_fewshot_current
|
|
154 |
interactive=fewshot_available,
|
155 |
)
|
156 |
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
|
160 |
def get_selected_task_type(task_type_id):
|
|
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
import pandas as pd
|
|
|
7 |
from datasets import load_dataset
|
8 |
|
9 |
import style
|
10 |
+
from style import T_SYMBOLS, MT_BENCH_LANG_SYMBOLS, LANG_SYMBOLS
|
11 |
|
12 |
ZERO_SHOT_ONLY = ["BELEBELE", "MT-Bench"]
|
13 |
FEW_SHOT_ONLY = ["GSM8K", "TruthfulQA"]
|
14 |
|
15 |
|
16 |
def init():
|
17 |
+
global repo_id, config_name, split_name, hidden_df, task_group_names_list, task_group_type_dict, task_groups_shots_dict, languages_list, model_type_dict, mt_bench_language_list
|
18 |
|
19 |
repo_id = os.getenv("OGX_LEADERBOARD_DATASET_NAME")
|
20 |
config_name = os.getenv("OGX_LEADERBOARD_DATASET_CONFIG")
|
|
|
29 |
task_groups_shots_df = hidden_df[hidden_df["Few_Shot"] == True][["Task_Group", "Number_Shots"]].drop_duplicates()
|
30 |
task_groups_shots_dict = task_groups_shots_df.set_index("Task_Group")["Number_Shots"].to_dict()
|
31 |
languages_list = hidden_df["Language"].drop_duplicates().str.upper().tolist()
|
32 |
+
mt_bench_language_list = hidden_df[hidden_df['Task_Group'] == "MTBENCH"]["Language"].drop_duplicates().str.upper().tolist()
|
33 |
model_type_df = hidden_df[["Model_Name", "Model_Type"]].drop_duplicates()
|
34 |
model_type_dict = model_type_df.set_index("Model_Name")["Model_Type"].to_dict()
|
35 |
|
|
|
128 |
return sort_cols(df, fewshot)
|
129 |
|
130 |
|
131 |
+
def update_task_groups_and_fewshot(current_selected_tab: int, model_types, langs_bar, is_fewshot_current: bool = False, ):
|
132 |
selected_task_type = get_selected_task_type(current_selected_tab)
|
133 |
available_tasks = get_available_task_groups(selected_task_type, is_fewshot_current)
|
134 |
new_selected_tasks = available_tasks.copy()
|
|
|
155 |
interactive=fewshot_available,
|
156 |
)
|
157 |
|
158 |
+
if current_selected_tab == 2:
|
159 |
+
model_types = gr.CheckboxGroup(
|
160 |
+
value=[T_SYMBOLS['chat']],
|
161 |
+
interactive=False
|
162 |
+
)
|
163 |
+
langs_bar = gr.CheckboxGroup(
|
164 |
+
choices=[(MT_BENCH_LANG_SYMBOLS.get(l, l), l) for l in mt_bench_language_list],
|
165 |
+
value=mt_bench_language_list,
|
166 |
+
interactive=True,
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
model_types = gr.CheckboxGroup(
|
170 |
+
label="Select model type",
|
171 |
+
choices=[
|
172 |
+
(
|
173 |
+
f"Pretrained {T_SYMBOLS['pretrained']}",
|
174 |
+
T_SYMBOLS["pretrained"],
|
175 |
+
),
|
176 |
+
(f"Chat {T_SYMBOLS['chat']}", T_SYMBOLS["chat"]),
|
177 |
+
],
|
178 |
+
value=list(T_SYMBOLS.values()),
|
179 |
+
interactive=True
|
180 |
+
)
|
181 |
+
langs_bar = gr.CheckboxGroup(
|
182 |
+
choices=[(LANG_SYMBOLS.get(l, l), l) for l in languages_list],
|
183 |
+
value=languages_list,
|
184 |
+
interactive=True,
|
185 |
+
)
|
186 |
+
|
187 |
+
return [tasks_checkbox_group_update, fewshot_radio_update, current_selected_tab, model_types, langs_bar]
|
188 |
|
189 |
|
190 |
def get_selected_task_type(task_type_id):
|
style.py
CHANGED
@@ -40,3 +40,10 @@ LANG_SYMBOLS = {
|
|
40 |
"SV": "🇸🇪 SV"
|
41 |
}
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
"SV": "🇸🇪 SV"
|
41 |
}
|
42 |
|
43 |
+
MT_BENCH_LANG_SYMBOLS = {
|
44 |
+
"ES": "🇪🇸 ES",
|
45 |
+
"EN": "🇬🇧 EN",
|
46 |
+
"DE": "🇩🇪 DE",
|
47 |
+
"FR": "🇫🇷 FR",
|
48 |
+
"IT": "🇮🇹 IT"
|
49 |
+
}
|