test: add unit tests for models
Files changed:
- src/models.py (+10 -9)
- tests/src/test_models.py (+33 -4)
src/models.py CHANGED

@@ -109,18 +109,19 @@ class FullEvalResult:
                 continue
             if eval_result.task != task:
                 continue
-            …
-            results[…
+            eval_name = eval_result.eval_name
+            results[eval_name]["eval_name"] = eval_name
+            results[eval_name][COL_NAME_RETRIEVAL_MODEL] = make_clickable_model(
                 self.retrieval_model, self.retrieval_model_link
             )
-            results[…
+            results[eval_name][COL_NAME_RERANKING_MODEL] = make_clickable_model(
                 self.reranking_model, self.reranking_model_link
             )
-            results[…
-            results[…
-            results[…
-            results[…
-            results[…
+            results[eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
+            results[eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
+            results[eval_name][COL_NAME_REVISION] = self.revision
+            results[eval_name][COL_NAME_TIMESTAMP] = self.timestamp
+            results[eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
 
             for result in eval_result.results:
                 # add result for each domain, language, and dataset

@@ -132,7 +133,7 @@ class FullEvalResult:
                     benchmark_name = f"{domain}_{lang}"
                 else:
                     benchmark_name = f"{domain}_{lang}_{dataset}"
-                results[…
+                results[eval_name][get_safe_name(benchmark_name)] = value
         return [v for v in results.values()]
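For context on the pattern this change relies on: rows are accumulated in a dict-of-dicts keyed by `eval_name`, with the metadata columns and one column per benchmark merged into the same row before the rows are returned as a list. Below is a minimal, self-contained sketch of that pattern; the sample data and the `defaultdict(dict)` container are illustrative assumptions, not taken verbatim from the diff.

```python
from collections import defaultdict

# Illustrative row-building pattern: one dict ("row") per eval_name,
# with metadata and per-benchmark scores merged into the same row.
results = defaultdict(dict)

# Made-up scores purely for illustration.
raw_scores = [
    {"eval_name": "bge-m3_NoReranker", "benchmark": "wiki_en", "value": 0.71},
    {"eval_name": "bge-m3_NoReranker", "benchmark": "law_en_lecard", "value": 0.42},
]

for item in raw_scores:
    eval_name = item["eval_name"]
    results[eval_name]["eval_name"] = eval_name            # metadata column
    results[eval_name][item["benchmark"]] = item["value"]  # one column per benchmark

rows = [v for v in results.values()]
print(rows)
# [{'eval_name': 'bge-m3_NoReranker', 'wiki_en': 0.71, 'law_en_lecard': 0.42}]
```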
tests/src/test_models.py CHANGED

@@ -6,6 +6,23 @@ from src.models import EvalResult, FullEvalResult
 cur_fp = Path(__file__)
 
 
+# Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
+# 24.05
+# | Task     | dev | test |
+# | ----     | --- | ---- |
+# | Long-Doc | 4   | 11   |
+# | QA       | 54  | 53   |
+#
+# 24.04
+# | Task     | test |
+# | ----     | ---- |
+# | Long-Doc | 15   |
+# | QA       | 13   |
+NUM_QA_BENCHMARKS_24_05 = 53
+NUM_DOC_BENCHMARKS_24_05 = 11
+NUM_QA_BENCHMARKS_24_04 = 13
+NUM_DOC_BENCHMARKS_24_04 = 15
+
 def test_eval_result():
     eval_result = EvalResult(
         eval_name="eval_name",

@@ -41,9 +58,21 @@ def test_full_eval_result_init_from_json_file(file_path):
     assert len(full_eval_result.results) == 70
 
 
-…
-…
+@pytest.mark.parametrize(
+    'file_path, task, expected_num_results',
+    [
+        ("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "qa", NUM_QA_BENCHMARKS_24_04),
+        ("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_04),
+        ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "qa", NUM_QA_BENCHMARKS_24_05),
+        ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_05),
+    ])
+def test_full_eval_result_to_dict(file_path, task, expected_num_results):
+    json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
     full_eval_result = FullEvalResult.init_from_json_file(json_fp)
-    result_dict_list = full_eval_result.to_dict()
+    result_dict_list = full_eval_result.to_dict(task)
     assert len(result_dict_list) == 1
-…
+    result = result_dict_list[0]
+    attr_list = frozenset([
+        'eval_name', 'Retrieval Method', 'Reranking Model', 'Retrieval Model LINK', 'Reranking Model LINK', 'Revision', 'Submission Date', 'Anonymous Submission'])
+    result_cols = list(result.keys())
+    assert len(result_cols) == (expected_num_results + len(attr_list))
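As a usage note, the new parametrized test can be run on its own. The sketch below shows one way to do that programmatically; it assumes pytest is installed and that the toy fixture files referenced in the parametrization exist under tests/toydata/eval_results/.

```python
# Run only the new parametrized test from the repository root; equivalent to
# invoking pytest on tests/src/test_models.py with "-k test_full_eval_result_to_dict".
import pytest

exit_code = pytest.main(["tests/src/test_models.py", "-k", "test_full_eval_result_to_dict", "-v"])
print("pytest exit code:", exit_code)
```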
|