Update utils.py
utils.py CHANGED
@@ -1,636 +1,636 @@
import base64
from huggingface_hub import hf_hub_download
import fasttext
import os
import json
import pandas as pd
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    balanced_accuracy_score,
    matthews_corrcoef
)
import numpy as np
from datasets import load_dataset

# Constants
MODEL_REPO = "atlasia/Sfaya-Moroccan-Darija-vs-All"
BIN_FILENAME = "model_multi_v3_2fpr.bin"
BINARY_LEADERBOARD_FILE = "darija_leaderboard_binary.json"
MULTILINGUAL_LEADERBOARD_FILE = "darija_leaderboard_multilingual.json"
DATA_PATH = "atlasia/Arabic-LID-Leaderboard"

target_label = "Morocco"
is_binary = False

# Load test dataset
test_dataset = load_dataset(DATA_PATH, split='test')

# Supported dialects
all_target_languages = list(test_dataset.unique("dialect"))
supported_dialects = all_target_languages + ['All']
languages_to_display_one_vs_all = all_target_languages  # everything except All

print(f'all_target_languages: {all_target_languages}')

metrics = [
    'f1_score',
    'precision',
    'recall',
    'specificity',
    'false_positive_rate',
    'false_negative_rate',
    'negative_predictive_value',
    'n_test_samples',
]

default_metrics = [
    'f1_score',
    'precision',
    'recall',
    'false_positive_rate',
    'false_negative_rate'
]

# default language to display in one-vs-all leaderboard
default_languages = [
-    'Morocco',
+    #'Morocco',
    'MSA',
-    'Egypt',
+    #'Egypt',
-    'Algeria',
+    #'Algeria',
-    'Tunisia',
+    #'Tunisia',
-    'Levantine',
+    #'Levantine',
]

language_mapping_dict = {
    'ace_Arab': 'Acehnese',
    'acm_Arab': 'Mesopotamia',  # 'Gilit Mesopotamian'
    'aeb_Arab': 'Tunisia',
    'ajp_Arab': 'Levantine',  # 'South Levantine'
    'apc_Arab': 'Levantine',
    'arb_Arab': 'MSA',
    'arq_Arab': 'Algeria',
    'ars_Arab': 'Saudi',  # Najdi is primarily Saudi Arabian
    'ary_Arab': 'Morocco',
    'arz_Arab': 'Egypt',
    'ayp_Arab': 'Mesopotamia',  # 'North Mesopotamian'
    'azb_Arab': 'Azerbaijan',  # South Azerbaijani pertains to this region
    'bcc_Arab': 'Balochistan',  # Southern Balochi is from Balochistan
    'bjn_Arab': 'Indonesia',  # Banjar is spoken in Indonesia
    'brh_Arab': 'Pakistan',  # Brahui is spoken in Pakistan
    'ckb_Arab': 'Kurdistan',  # Central Kurdish is mainly in Iraq
    'fuv_Arab': 'Nigeria',  # Hausa States Fulfulde
    'glk_Arab': 'Iran',  # Gilaki is spoken in Iran
    'hac_Arab': 'Iran',  # Gurani is also primarily spoken in Iran
    'kas_Arab': 'Kashmir',
    'knc_Arab': 'Nigeria',  # Central Kanuri is in Nigeria
    'lki_Arab': 'Iran',  # Laki is from Iran
    'lrc_Arab': 'Iran',  # Northern Luri is from Iran
    'min_Arab': 'Indonesia',  # Minangkabau is spoken in Indonesia
    'mzn_Arab': 'Iran',  # Mazanderani is spoken in Iran
    'ota_Arab': 'Turkey',  # Ottoman Turkish
    'pbt_Arab': 'Afghanistan',  # Southern Pashto
    'pnb_Arab': 'Pakistan',  # Western Panjabi
    'sdh_Arab': 'Iraq',  # Southern Kurdish
    'shu_Arab': 'Chad',  # Chadian Arabic
    'skr_Arab': 'Pakistan',  # Saraiki
    'snd_Arab': 'Pakistan',  # Sindhi
    'sus_Arab': 'Guinea',  # Susu
    'tuk_Arab': 'Turkmenistan',  # Turkmen
    'uig_Arab': 'Uighur (China)',  # Uighur
    'urd_Arab': 'Pakistan',  # Urdu
    'uzs_Arab': 'Uzbekistan',  # Southern Uzbek
    'zsm_Arab': 'Malaysia'  # Standard Malay
}

def predict_label(text, model, language_mapping_dict, use_mapping=False):
    # Remove any newline characters and strip whitespace
    text = str(text).strip().replace('\n', ' ')

    if text == '':
        return 'Other'

    try:
        # Get top prediction
        prediction = model.predict(text, 1)

        # Extract label and remove __label__ prefix
        label = prediction[0][0].replace('__label__', '')

        # Extract confidence score
        confidence = prediction[1][0]

        # map label to language using language_mapping_dict
        if use_mapping:
            label = language_mapping_dict.get(label, 'Other')
        return label

    except Exception as e:
        print(f"Error processing text: {text}")
        print(f"Exception: {e}")
        return {'prediction_label': 'Error', 'prediction_confidence': 0.0}

def compute_classification_metrics(test_dataset):
    """
    Compute comprehensive classification metrics for each class.

    Args:
        data (pd.DataFrame): DataFrame containing 'dialect' as true labels and 'preds' as predicted labels.

    Returns:
        pd.DataFrame: DataFrame with detailed metrics for each class.
    """
    # transform the dataset into a DataFrame
    data = pd.DataFrame(test_dataset)
    # Extract true labels and predictions
    true_labels = list(data['dialect'])
    predicted_labels = list(data['preds'])

    # Handle all unique labels
    labels = sorted(list(set(true_labels + predicted_labels)))
    label_to_index = {label: index for index, label in enumerate(labels)}

    # Convert labels to indices
    true_indices = [label_to_index[label] for label in true_labels]
    pred_indices = [label_to_index[label] for label in predicted_labels]

    # Compute basic metrics
    f1_scores = f1_score(true_indices, pred_indices, average=None, labels=range(len(labels)))
    precision_scores = precision_score(true_indices, pred_indices, average=None, labels=range(len(labels)))
    recall_scores = recall_score(true_indices, pred_indices, average=None, labels=range(len(labels)))

    # Compute confusion matrix
    conf_mat = confusion_matrix(true_indices, pred_indices, labels=range(len(labels)))

    # Calculate various metrics per class
    FP = conf_mat.sum(axis=0) - np.diag(conf_mat)  # False Positives
    FN = conf_mat.sum(axis=1) - np.diag(conf_mat)  # False Negatives
    TP = np.diag(conf_mat)  # True Positives
    TN = conf_mat.sum() - (FP + FN + TP)  # True Negatives

    # Calculate sample counts per class
    samples_per_class = np.bincount(true_indices, minlength=len(labels))

    # Calculate additional metrics
    with np.errstate(divide='ignore', invalid='ignore'):
        fp_rate = FP / (FP + TN)  # False Positive Rate
        fn_rate = FN / (FN + TP)  # False Negative Rate
        specificity = TN / (TN + FP)  # True Negative Rate
        npv = TN / (TN + FN)  # Negative Predictive Value

    # Replace NaN/inf with 0
    metrics = [fp_rate, fn_rate, specificity, npv]
    metrics = [np.nan_to_num(m, nan=0.0, posinf=0.0, neginf=0.0) for m in metrics]
    fp_rate, fn_rate, specificity, npv = metrics

    # Calculate overall metrics
    balanced_acc = balanced_accuracy_score(true_indices, pred_indices)
    mcc = matthews_corrcoef(true_indices, pred_indices)

    # Compile results into a DataFrame
    result_df = pd.DataFrame({
        'country': labels,
        'samples': samples_per_class,
        'f1_score': f1_scores,
        'precision': precision_scores,
        'recall': recall_scores,
        'specificity': specificity,
        'false_positive_rate': fp_rate,
        'false_negative_rate': fn_rate,
        'true_positives': TP,
        'false_positives': FP,
        'true_negatives': TN,
        'false_negatives': FN,
        'negative_predictive_value': npv
    })

    # Sort by number of samples (descending)
    result_df = result_df.sort_values('samples', ascending=False)

    # Calculate and add summary metrics
    summary_metrics = {
        'macro_f1': f1_score(true_indices, pred_indices, average='macro'),
        'weighted_f1': f1_score(true_indices, pred_indices, average='weighted'),
        'micro_f1': f1_score(true_indices, pred_indices, average='micro'),
        'balanced_accuracy': balanced_acc,
        'matthews_correlation': mcc
    }

    # Format all numeric columns to 4 decimal places
    numeric_cols = result_df.select_dtypes(include=[np.number]).columns
    result_df[numeric_cols] = result_df[numeric_cols].round(4)

    print(f'result_df: {result_df}')

    return result_df, summary_metrics

def make_binary(dialect, target):
    if dialect != target:
        return 'Other'
    return target

def run_eval_one_vs_all(data_test, TARGET_LANG='Morocco'):

    # map to binary
    df_test_preds = data_test.copy()
    df_test_preds.loc[df_test_preds['dialect'] == TARGET_LANG, 'dialect'] = TARGET_LANG
    df_test_preds.loc[df_test_preds['dialect'] != TARGET_LANG, 'dialect'] = 'Other'

    # compute the fpr per dialect
    dialect_counts = data_test.groupby('dialect')['dialect'].count().reset_index(name='size')
    result_df = pd.merge(dialect_counts, data_test, on='dialect')
    result_df = result_df.groupby(['dialect', 'size', 'preds'])['preds'].count()/result_df.groupby(['dialect', 'size'])['preds'].count()
    result_df.sort_index(ascending=False, level='size', inplace=True)

    # group by dialect and get the false positive rate
    out = result_df.copy()
    out.name = 'false_positive_rate'
    out = out.reset_index()
    out = out[out['preds']==TARGET_LANG].drop(columns=['preds', 'size'])

    print(f'out for TARGET_LANG={TARGET_LANG} \n: {out}')

    return out

def update_darija_one_vs_all_leaderboard(result_df, model_name, target_lang, BINARY_LEADERBOARD_FILE="darija_leaderboard_binary.json"):
    try:
        with open(BINARY_LEADERBOARD_FILE, "r") as f:
            data = json.load(f)
    except FileNotFoundError:
        data = []

    # Process the results for each dialect/country
    for _, row in result_df.iterrows():
        dialect = row['dialect']
        # Skip 'Other' class, it is considered as the null space
        if dialect == 'Other':
            continue

        # Find existing target_lang entry or create a new one
        target_entry = next((item for item in data if target_lang in item), None)
        if target_entry is None:
            target_entry = {target_lang: {}}
            data.append(target_entry)

        # Get the country-specific data for this target language
        country_data = target_entry[target_lang]

        # Initialize the dialect/country entry if it doesn't exist
        if dialect not in country_data:
            country_data[dialect] = {}

        # Update the model metrics under the model name for the given dialect
        country_data[dialect][model_name] = float(row['false_positive_rate'])

        # # Add the number of test samples, if not already present
        # if "n_test_samples" not in country_data[dialect]:
        #     country_data[dialect]["n_test_samples"] = int(row['size'])

    # Save updated leaderboard data
    with open(BINARY_LEADERBOARD_FILE, "w") as f:
        json.dump(data, f, indent=4)

def handle_evaluation(model_path, model_path_bin, use_mapping=False):

    # download model and get the model path
    model_path_hub = hf_hub_download(repo_id=model_path, filename=model_path_bin, cache_dir=None)

    # Load the trained model
    print(f"[INFO] Loading model from Path: {model_path_hub}, using version {model_path_bin}...")
    model = fasttext.load_model(model_path_hub)

    # Load the evaluation dataset
    print(f"[INFO] Loading evaluation dataset from Path: {DATA_PATH}...")
    eval_dataset = load_dataset(DATA_PATH, split='test')

    # Transform to pandas DataFrame
    print(f"[INFO] Converting evaluation dataset to Pandas DataFrame...")
    df_eval = pd.DataFrame(eval_dataset)

    # Predict labels using the model
    print(f"[INFO] Running predictions...")
    df_eval['preds'] = df_eval['text'].apply(lambda text: predict_label(text, model, language_mapping_dict, use_mapping=use_mapping))

    # run the evaluation
    result_df, _ = run_eval(df_eval)
    # set the model name
    model_name = model_path + '/' + model_path_bin

    # update the multilingual leaderboard
    update_darija_multilingual_leaderboard(result_df, model_name, MULTILINGUAL_LEADERBOARD_FILE)

    for target_lang in all_target_languages:
        result_df_one_vs_all = run_eval_one_vs_all(df_eval, TARGET_LANG=target_lang)
        update_darija_one_vs_all_leaderboard(result_df_one_vs_all, model_name, target_lang, BINARY_LEADERBOARD_FILE)

    # load the updated leaderboard tables
    df_multilingual = load_leaderboard_multilingual()
    df_one_vs_all = load_leaderboard_one_vs_all()

    status_message = "**Evaluation now ended! 🤗**"

    return create_leaderboard_display_multilingual(df_multilingual, target_label, default_metrics), status_message

def run_eval(df_eval):
    """Run evaluation on a dataset and compute metrics.

    Args:
        model: The model to evaluate.
        DATA_PATH (str): Path to the dataset.
        is_binary (bool): If True, evaluate as binary classification.
            If False, evaluate as multi-class classification.
        target_label (str): The target class label in binary mode.

    Returns:
        pd.DataFrame: A DataFrame containing evaluation metrics.
    """

    # map to binary
    df_eval_multilingual = df_eval.copy()

    # now drop the columns that are not needed, i.e. 'text'
    df_eval_multilingual = df_eval_multilingual.drop(columns=['text', 'metadata', 'dataset_source'])

    # Compute evaluation metrics
    print(f"[INFO] Computing metrics...")
    result_df, _ = compute_classification_metrics(df_eval_multilingual)

    # update_darija_multilingual_leaderboard(result_df, model_path, MULTILINGUAL_LEADERBOARD_FILE)

    return result_df, df_eval_multilingual

def process_results_file(file, uploaded_model_name, base_path_save="./atlasia/submissions/", default_language='Morocco'):
    try:
        if file is None:
            return "Please upload a file."

        # Clean the model name to be safe for file paths
        uploaded_model_name = uploaded_model_name.strip().replace(" ", "_")
        print(f"[INFO] uploaded_model_name: {uploaded_model_name}")

        # Create the directory for saving submissions
        path_saving = os.path.join(base_path_save, uploaded_model_name)
        os.makedirs(path_saving, exist_ok=True)

        # Define the full path to save the file
        saved_file_path = os.path.join(path_saving, 'submission.csv')

        # Read the uploaded file as DataFrame
        print(f"[INFO] Loading results...")
        df_eval = pd.read_csv(file.name)

        # Save the DataFrame
        print(f"[INFO] Saving the file locally in: {saved_file_path}")
        df_eval.to_csv(saved_file_path, index=False)

    except Exception as e:
        return f"Error processing file: {str(e)}"

    # Compute evaluation metrics
    print(f"[INFO] Computing metrics...")
    result_df, _ = compute_classification_metrics(df_eval)

    # Update the leaderboards
    update_darija_multilingual_leaderboard(result_df, uploaded_model_name, MULTILINGUAL_LEADERBOARD_FILE)

    # TODO: implement this ove_vs_all differently for people only submitting csv file. They need to submit two files, one for multi-lang and the other for one-vs-all
    # result_df_one_vs_all = run_eval_one_vs_all(...)
    # update_darija_one_vs_all_leaderboard(...)

    # update the leaderboard table
    df = load_leaderboard_multilingual()

    return create_leaderboard_display_multilingual(df, default_language, default_metrics)

def update_darija_multilingual_leaderboard(result_df, model_name, MULTILINGUAL_LEADERBOARD_FILE="darija_leaderboard_multilingual.json"):

    # Load leaderboard data
    current_dir = os.path.dirname(os.path.abspath(__file__))
    MULTILINGUAL_LEADERBOARD_FILE = os.path.join(current_dir, MULTILINGUAL_LEADERBOARD_FILE)

    try:
        with open(MULTILINGUAL_LEADERBOARD_FILE, "r") as f:
            data = json.load(f)
    except FileNotFoundError:
        data = []

    # Process the results for each dialect/country
    for _, row in result_df.iterrows():
        country = row['country']
        # skip 'Other' class, it is considered as the null space
        if country == 'Other':
            continue

        # Create metrics dictionary directly
        metrics = {
            'f1_score': float(row['f1_score']),
            'precision': float(row['precision']),
            'recall': float(row['recall']),
            'specificity': float(row['specificity']),
            'false_positive_rate': float(row['false_positive_rate']),
            'false_negative_rate': float(row['false_negative_rate']),
            'negative_predictive_value': float(row['negative_predictive_value']),
            'n_test_samples': int(row['samples'])
        }

        # Find existing country entry or create new one
        country_entry = next((item for item in data if country in item), None)
        if country_entry is None:
            country_entry = {country: {}}
            data.append(country_entry)

        # Update the model metrics directly under the model name
        if country not in country_entry:
            country_entry[country] = {}
        country_entry[country][model_name] = metrics

    # Save updated leaderboard data
    with open(MULTILINGUAL_LEADERBOARD_FILE, "w") as f:
        json.dump(data, f, indent=4)


def load_leaderboard_one_vs_all(BINARY_LEADERBOARD_FILE="darija_leaderboard_binary.json"):
    current_dir = os.path.dirname(os.path.abspath(__file__))
    BINARY_LEADERBOARD_FILE = os.path.join(current_dir, BINARY_LEADERBOARD_FILE)

    with open(BINARY_LEADERBOARD_FILE, "r") as f:
        data = json.load(f)

    # Initialize lists to store the flattened data
    rows = []

    # Process each target language's data
    for leaderboard_data in data:
        for target_language, results in leaderboard_data.items():
            for language, models in results.items():

                for model_name, false_positive_rate in models.items():

                    row = {
                        'target_language': target_language,
                        'language': language,
                        'model': model_name,
                        'false_positive_rate': false_positive_rate,
                    }
                    # Add all metrics to the row
                    rows.append(row)

    # Convert to DataFrame
    df = pd.DataFrame(rows)

    # Pivot the DataFrame to create the desired structure: all languages in columns and models in rows, and each (model, target_language, language) = false_positive_rate
    df_pivot = df.pivot(index=['model', 'target_language'], columns='language', values='false_positive_rate').reset_index()

    # print(f'df_pivot \n: {df_pivot}')

    return df_pivot

def load_leaderboard_multilingual(MULTILINGUAL_LEADERBOARD_FILE="darija_leaderboard_multilingual.json"):
    current_dir = os.path.dirname(os.path.abspath(__file__))
    MULTILINGUAL_LEADERBOARD_FILE = os.path.join(current_dir, MULTILINGUAL_LEADERBOARD_FILE)

    with open(MULTILINGUAL_LEADERBOARD_FILE, "r") as f:
        data = json.load(f)

    # Initialize lists to store the flattened data
    rows = []

    # Process each country's data
    for country_data in data:
        for country, models in country_data.items():
            for model_name, metrics in models.items():
                row = {
                    'country': country,
                    'model': model_name,
                }
                # Add all metrics to the row
                row.update(metrics)
                rows.append(row)

    # Convert to DataFrame
    df = pd.DataFrame(rows)
    return df

def create_leaderboard_display_one_vs_all(df, target_language, selected_languages):

    # Filter by target_language if specified
    if target_language:
        df = df[df['target_language'] == target_language]

        # Remove the target_language from selected_languages
        if target_language in selected_languages:
            selected_languages = [lang for lang in selected_languages if lang != target_language]

    # Select only the chosen languages (plus 'model' column)
    columns_to_show = ['model'] + [language for language in selected_languages if language in df.columns]

    # Sort by first selected metric by default
    if selected_languages:
        df = df.sort_values(by=selected_languages[0], ascending=False)

    df = df[columns_to_show]

    # Format numeric columns to 4 decimal places
    numeric_cols = df.select_dtypes(include=['float64']).columns
    df[numeric_cols] = df[numeric_cols].round(4)

    return df, selected_languages


def create_leaderboard_display_multilingual(df, selected_country, selected_metrics):
    # Filter by country if specified
    if selected_country and selected_country.upper() != 'ALL':
        # print(f"Filtering leaderboard by country: {selected_country}")
        df = df[df['country'] == selected_country]
        df = df.drop(columns=['country'])

        # Select only the chosen metrics (plus 'model' column)
        columns_to_show = ['model'] + [metric for metric in selected_metrics if metric in df.columns]

    else:
        # Select all metrics (plus 'country' and 'model' columns), if no country is selected or 'All' is selected for ease of comparison
        columns_to_show = ['model', 'country'] + selected_metrics

    # Sort by first selected metric by default
    if selected_metrics:
        df = df.sort_values(by=selected_metrics[0], ascending=False)

    df = df[columns_to_show]

    # Format numeric columns to 4 decimal places
    numeric_cols = df.select_dtypes(include=['float64']).columns
    df[numeric_cols] = df[numeric_cols].round(4)

    return df

def update_leaderboard_multilingual(country, selected_metrics):
    if not selected_metrics:  # If no metrics selected, show all
        selected_metrics = metrics
    df = load_leaderboard_multilingual()
    display_df = create_leaderboard_display_multilingual(df, country, selected_metrics)
    return display_df

def update_leaderboard_one_vs_all(target_language, selected_languages):
    if not selected_languages:  # If no language selected, show all defaults
        selected_languages = default_languages
    df = load_leaderboard_one_vs_all()
    display_df, selected_languages = create_leaderboard_display_one_vs_all(df, target_language, selected_languages)
    # to improve visibility in case the user chooses multiple language leading to many columns, the `model` column must remain fixed
    # display_df = render_fixed_columns(display_df)
    return display_df, selected_languages

def encode_image_to_base64(image_path):
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
    return encoded_string

def create_html_image(image_path):
    # Get base64 string of image
    img_base64 = encode_image_to_base64(image_path)

    # Create HTML string with embedded image and centering styles
    html_string = f"""
    <div style="display: flex; justify-content: center; align-items: center; width: 100%; text-align: center;">
        <div style="max-width: 800px; margin: auto;">
            <img src="data:image/jpeg;base64,{img_base64}"
                 style="max-width: 75%; height: auto; display: block; margin: 0 auto; margin-top: 50px;"
                 alt="Displayed Image">
        </div>
    </div>
    """
    return html_string

# Function to render HTML table with fixed 'model' column
def render_fixed_columns(df):
    style = """
    <style>
        .table-container {
            overflow-x: auto;
            position: relative;
            white-space: nowrap;
        }
        table {
            border-collapse: collapse;
            width: 100%;
        }
        th, td {
            border: 1px solid black;
            padding: 8px;
            text-align: left;
        }
        th.fixed, td.fixed {
            position: sticky;
            left: 0;
            background-color: white;
            z-index: 2;
        }
    </style>
    """
    table_html = df.to_html(index=False).replace(
        "<th>model</th>", '<th class="fixed">model</th>'
    ).replace(
        '<td>', '<td class="fixed">', 1
    )
    return f"{style}<div class='table-container'>{table_html}</div>"
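
The practical effect of this commit is on the one-vs-all leaderboard's default columns: with no languages selected, update_leaderboard_one_vs_all falls back to default_languages, which now contains only 'MSA'. The minimal sketch below illustrates that behaviour; it is not part of the commit, and it assumes utils.py is importable from the same directory (importing it downloads the atlasia/Arabic-LID-Leaderboard test split) and that darija_leaderboard_binary.json exists next to it with a 'Morocco' target entry that includes an 'MSA' column.

# Hypothetical check of the new default, not part of the commit.
# Assumes utils.py and darija_leaderboard_binary.json are present and populated.
from utils import update_leaderboard_one_vs_all

# An empty selection falls back to default_languages, which is now just ['MSA'].
display_df, shown_languages = update_leaderboard_one_vs_all(
    target_language='Morocco',
    selected_languages=[],
)
print(shown_languages)    # expected: ['MSA']
print(display_df.head())  # columns: 'model' and 'MSA' (false positive rates)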