Upload from GitHub Actions: TruthfulQA translation WIP

Files changed:
- evals/datasets_/arc.py +1 -0
- evals/datasets_/truthfulqa.py +72 -0
- evals/plots.py +2 -3
evals/datasets_/arc.py CHANGED

@@ -54,6 +54,7 @@ def load_uhura_arc_easy(language_bcp_47, nr):
         ds = ds.rename_column("answerKey", "answer")
         train_ids = common_ids_train[nr:nr+3]
         examples = ds["train"].filter(lambda x: x["id"] in train_ids)
+        # raise Exception(language_bcp_47)
         task = ds["test"].filter(lambda x: x["id"] == common_ids_test[nr])[0]
         return "fair-forward/arc-easy-autotranslated", examples, task
     else:
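The added line itself is only a commented-out debugging raise. The surrounding context does show the few-shot windowing, though: for evaluation index nr, the three train items at positions nr..nr+2 become in-context examples, while the test item at position nr is the task to answer. A minimal illustration, with hypothetical ids:

nr = 2
common_ids_train = ["q0", "q1", "q2", "q3", "q4", "q5"]  # hypothetical ids
train_ids = common_ids_train[nr:nr + 3]
# train_ids == ["q2", "q3", "q4"] -- the three few-shot examples for task nr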
evals/datasets_/truthfulqa.py CHANGED

@@ -3,6 +3,13 @@ from collections import Counter, defaultdict
 
 from langcodes import Language, standardize_tag
 from rich import print
+from tqdm import tqdm
+import asyncio
+from tqdm.asyncio import tqdm_asyncio
+import os
+
+from datasets import Dataset, load_dataset
+from models import translate_google, google_supported_languages
 
 from datasets_.util import _get_dataset_config_names, _load_dataset
 
@@ -28,3 +35,68 @@ def load_truthfulqa(language_bcp_47, nr):
         return "masakhane/uhura-truthfulqa", examples, task
     else:
         return None, None, None
+
+
+
+def translate_truthfulqa(languages):
+    human_translated = [*tags_uhura_truthfulqa.keys()]
+    untranslated = [
+        lang
+        for lang in languages["bcp_47"].values[:100]
+        if lang not in human_translated and lang in google_supported_languages
+    ]
+    n_samples = 10
+
+    slug = "fair-forward/truthfulqa-autotranslated"
+    for lang in tqdm(untranslated):
+        # check if already exists on hub
+        try:
+            ds_lang = load_dataset(slug, lang)
+        except (ValueError, Exception):
+            print(f"Translating {lang}...")
+            for split in ["train", "test"]:
+                ds = _load_dataset(slug_uhura_truthfulqa, tags_uhura_truthfulqa["en"], split=split)
+                samples = []
+                if split == "train":
+                    samples.extend(ds)
+                else:
+                    for i in range(n_samples):
+                        task = ds[i]
+                        samples.append(task)
+                questions_tr = [
+                    translate_google(s["question"], "en", lang) for s in samples
+                ]
+                questions_tr = asyncio.run(tqdm_asyncio.gather(*questions_tr))
+                choices_texts_concatenated = []
+                for s in samples:
+                    for choice in eval(s["choices"]):
+                        choices_texts_concatenated.append(choice)
+                choices_tr = [
+                    translate_google(c, "en", lang) for c in choices_texts_concatenated
+                ]
+                choices_tr = asyncio.run(tqdm_asyncio.gather(*choices_tr))
+                # group into chunks of 4
+                choices_tr = [
+                    choices_tr[i : i + 4] for i in range(0, len(choices_tr), 4)
+                ]
+
+                ds_lang = Dataset.from_dict(
+                    {
+                        "subject": [s["subject"] for s in samples],
+                        "question": questions_tr,
+                        "choices": choices_tr,
+                        "answer": [s["answer"] for s in samples],
+                    }
+                )
+                ds_lang.push_to_hub(
+                    slug,
+                    split=split,
+                    config_name=lang,
+                    token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
+                )
+                ds_lang.to_json(
+                    f"data/translations/mmlu/{lang}_{split}.json",
+                    lines=False,
+                    force_ascii=False,
+                    indent=2,
+                )
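A note on the concurrency pattern above: the list comprehensions pass coroutines (not strings) into tqdm_asyncio.gather, which implies translate_google is an async function. A minimal sketch of that pattern, with a hypothetical translate_fn standing in for it:

import asyncio
from tqdm.asyncio import tqdm_asyncio

async def translate_all(texts, source_lang, target_lang, translate_fn):
    # Build one coroutine per text and await them all concurrently,
    # with a progress bar over the pending coroutines.
    coros = [translate_fn(t, source_lang, target_lang) for t in texts]
    return await tqdm_asyncio.gather(*coros)

# e.g. questions_tr = asyncio.run(translate_all(questions, "en", lang, translate_google))

The choices are flattened across samples before translation and then regrouped into chunks of 4, which assumes every question has exactly four answer options. A few things likely worth revisiting while this is WIP: except (ValueError, Exception) is equivalent to except Exception, so any failure (not just a missing config on the hub) triggers retranslation; eval(s["choices"]) would be safer as ast.literal_eval; and the output path data/translations/mmlu/... looks like it may have been carried over from the MMLU translation script rather than a TruthfulQA-specific directory.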
evals/plots.py CHANGED

@@ -45,7 +45,7 @@ pivot_df = pivot_df[[task for task in ordered_tasks if task in pivot_df.columns]]
 correlation_matrix = pivot_df.corr()
 
 # Create the correlation plot
-plt.figure(figsize=(
+plt.figure(figsize=(8, 6))
 # Create mask for upper triangle including diagonal to show only lower triangle
 mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
 
@@ -53,7 +53,7 @@ mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
 sns.heatmap(
     correlation_matrix,
     annot=True,
-    cmap='
+    cmap='Blues',
     center=0,
     square=True,
     mask=mask,
@@ -61,7 +61,6 @@ sns.heatmap(
     fmt='.3f'
 )
 
-plt.title('Task Performance Correlation Matrix', fontsize=16, fontweight='bold')
 plt.xlabel('Tasks', fontsize=12)
 plt.ylabel('Tasks', fontsize=12)
 plt.xticks(rotation=45, ha='right')
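Assembled, the post-commit plotting section reads roughly as below. This is a sketch: pivot_df construction is omitted, the pre-change figsize and cmap values are cut off in the extracted diff view, and one heatmap argument between the hunks is not shown at all.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

correlation_matrix = pivot_df.corr()

# Create the correlation plot
plt.figure(figsize=(8, 6))
# Mask the upper triangle including the diagonal so only the lower triangle shows
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))

sns.heatmap(
    correlation_matrix,
    annot=True,
    cmap='Blues',  # sequential colormap introduced by this commit
    center=0,      # kept from before; centers the colormap normalization at 0
    square=True,
    mask=mask,
    # (one argument at original line 60 is not visible in the diff)
    fmt='.3f',
)

plt.xlabel('Tasks', fontsize=12)
plt.ylabel('Tasks', fontsize=12)
plt.xticks(rotation=45, ha='right')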