import base64
import itertools
import os
from collections import defaultdict

import gradio as gr
import pandas as pd
from datasets import Dataset, load_dataset

# Hugging Face token, expected as an environment/Space secret
TOKEN = os.environ['TOKEN']

MASKED_LM_MODELS = [
    "BounharAbdelaziz/XLM-RoBERTa-Morocco",
    "SI2M-Lab/DarijaBERT",
    "BounharAbdelaziz/ModernBERT-Morocco",
    "google-bert/bert-base-multilingual-cased",
    "FacebookAI/xlm-roberta-large",
    "aubmindlab/bert-base-arabertv02",
]

CAUSAL_LM_MODELS = [
    "BounharAbdelaziz/Al-Atlas-LLM-0.5B",
    "Qwen/Qwen2.5-0.5B",
    "tiiuae/Falcon3-1B-Base",
    "MBZUAI-Paris/Atlas-Chat-2B",
]
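
# Note: itertools.combinations below yields every unordered model pair, so the
# 6 masked models give C(6, 2) = 15 pairs and the 4 causal models C(4, 2) = 6;
# a dataset of N rows therefore produces at most 15 * N masked battles and
# 6 * N causal battles.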

def encode_image_to_base64(image_path):
    """Encode an image or GIF file to base64."""
    with open(image_path, "rb") as file:
        encoded_string = base64.b64encode(file.read()).decode()
    return encoded_string


def create_html_media(media_path, is_gif=False):
    """Create HTML for displaying an image or GIF."""
    media_base64 = encode_image_to_base64(media_path)
    media_type = "gif" if is_gif else "jpeg"
    html_string = f"""
    <div style="display: flex; justify-content: center; align-items: center; width: 100%; text-align: center;">
        <div style="max-width: 450px; margin: auto;">
            <img src="data:image/{media_type};base64,{media_base64}"
                 style="max-width: 75%; height: auto; display: block; margin: 0 auto; margin-top: 50px;"
                 alt="Displayed Media">
        </div>
    </div>
    """
    return html_string
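
# Minimal usage sketch: the returned HTML embeds the file as a base64 data URI,
# so the media needs no separate static route, e.g.:
#   gr.HTML(create_html_media("battle_leaderboard.gif", is_gif=True))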

class LMBattleArena:
    def __init__(self, dataset_path, saving_freq=25):
        """Initialize the battle arena with the evaluation dataset."""
        self.df = pd.read_csv(dataset_path)
        self.current_index = 0
        self.saving_freq = saving_freq  # save to CSV / push to Hub every `saving_freq` evaluations
        self.evaluation_results_masked = []
        self.evaluation_results_causal = []
        self.model_scores = defaultdict(lambda: {'wins': 0, 'total_comparisons': 0})
        # Generate all possible model pairs
        self.masked_model_pairs = list(itertools.combinations(MASKED_LM_MODELS, 2))
        self.causal_model_pairs = list(itertools.combinations(CAUSAL_LM_MODELS, 2))
        # Pair indices to track which pair is being evaluated
        self.masked_pair_idx = 0
        self.causal_pair_idx = 0
        # Tracks which rows have been evaluated for which model pairs
        self.row_model_pairs_evaluated = set()
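
    # Keys in `row_model_pairs_evaluated` take the form
    # "<row_index>_<masked|causal>_<pair_index>", e.g. "0_causal_3".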

    def get_next_battle_pair(self, is_causal):
        """Retrieve the next pair of model outputs to compare, ensuring every model pair is evaluated."""
        if self.current_index >= len(self.df):
            # Reset the index to go through the dataset again with the remaining model pairs
            self.current_index = 0
            # If we've gone through all model pairs for all rows, we're done
            if is_causal and self.causal_pair_idx >= len(self.causal_model_pairs):
                return None
            elif not is_causal and self.masked_pair_idx >= len(self.masked_model_pairs):
                return None

        row = self.df.iloc[self.current_index]

        # Get the current model pair to evaluate
        if is_causal:
            # Check if we've evaluated all causal model pairs for this row
            if self.causal_pair_idx >= len(self.causal_model_pairs):
                # Move to the next row and reset the pair index
                self.current_index += 1
                self.causal_pair_idx = 0
                # Try again with the next row
                return self.get_next_battle_pair(is_causal)

            model_pair = self.causal_model_pairs[self.causal_pair_idx]
            pair_key = f"{self.current_index}_causal_{self.causal_pair_idx}"

            # Skip row-pair combinations that have already been evaluated
            if pair_key in self.row_model_pairs_evaluated:
                self.causal_pair_idx += 1
                return self.get_next_battle_pair(is_causal)

            # Mark this row-pair combination as evaluated
            self.row_model_pairs_evaluated.add(pair_key)

            # Advance to the next pair for the next evaluation
            self.causal_pair_idx += 1
            # If all pairs for this row are done, move to the next row
            if self.causal_pair_idx >= len(self.causal_model_pairs):
                self.causal_pair_idx = 0
                self.current_index += 1
        else:
            # Same logic for masked models
            if self.masked_pair_idx >= len(self.masked_model_pairs):
                self.current_index += 1
                self.masked_pair_idx = 0
                return self.get_next_battle_pair(is_causal)

            model_pair = self.masked_model_pairs[self.masked_pair_idx]
            pair_key = f"{self.current_index}_masked_{self.masked_pair_idx}"

            if pair_key in self.row_model_pairs_evaluated:
                self.masked_pair_idx += 1
                return self.get_next_battle_pair(is_causal)

            self.row_model_pairs_evaluated.add(pair_key)
            self.masked_pair_idx += 1
            if self.masked_pair_idx >= len(self.masked_model_pairs):
                self.masked_pair_idx = 0
                self.current_index += 1

        # Prepare the battle data with the selected model pair
        battle_data = {
            'prompt': row['causal_sentence'] if is_causal else row['masked_sentence'],
            'model_1': row[model_pair[0]],
            'model_2': row[model_pair[1]],
            'model1_name': model_pair[0],
            'model2_name': model_pair[1],
        }
        return battle_data
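
    # Note: each skipped (row, pair) combination adds one recursive call, so on
    # very large datasets a long skip chain could in principle approach Python's
    # default recursion limit (~1000 frames); an iterative loop would avoid that.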

    def record_evaluation(self, preferred_models, input_text, output1, output2, model1_name, model2_name, is_causal):
        """Record the user's model preference and update the scores."""
        self.model_scores[model1_name]['total_comparisons'] += 1
        self.model_scores[model2_name]['total_comparisons'] += 1

        if preferred_models == "Both Good":
            self.model_scores[model1_name]['wins'] += 1
            self.model_scores[model2_name]['wins'] += 1
        elif preferred_models == "Model A":  # maps to the first model
            self.model_scores[model1_name]['wins'] += 1
        elif preferred_models == "Model B":  # maps to the second model
            self.model_scores[model2_name]['wins'] += 1
        # "Both Bad" case - no wins recorded

        evaluation = {
            'input_text': input_text,
            'output1': output1,
            'output2': output2,
            'model1_name': model1_name,
            'model2_name': model2_name,
            'preferred_models': preferred_models,
        }
        if is_causal:
            self.evaluation_results_causal.append(evaluation)
        else:
            self.evaluation_results_masked.append(evaluation)

        # Save results periodically
        total_evaluations = len(self.evaluation_results_causal) + len(self.evaluation_results_masked)
        if total_evaluations % self.saving_freq == 0:
            self.save_results()

        return self.get_model_scores_df(is_causal)
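
    # Scoring semantics: "Both Good" credits a win to both models and "Both Bad"
    # to neither, while every vote increments both models' comparison counts, so
    # each model's win rate (wins / total_comparisons) stays within [0, 100] but
    # the Wins column does not sum to the number of votes cast.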

    def save_results(self):
        """Save the evaluation results to the Hub and to CSV."""
        # Combine the masked and causal leaderboards and push them to the Hub
        results_df = pd.concat(
            [self.get_model_scores_df(is_causal=False), self.get_model_scores_df(is_causal=True)],
            ignore_index=True,
        )
        results_dataset = Dataset.from_pandas(results_df)
        results_dataset.push_to_hub('atlasia/Res-Moroccan-Darija-LLM-Battle-Al-Atlas', private=True, token=TOKEN)
        results_df.to_csv('human_eval_results.csv', index=False)

        # Also save the raw evaluation records
        masked_df = pd.DataFrame(self.evaluation_results_masked)
        causal_df = pd.DataFrame(self.evaluation_results_causal)
        if not masked_df.empty:
            masked_df.to_csv('masked_evaluations.csv', index=False)
        if not causal_df.empty:
            causal_df.to_csv('causal_evaluations.csv', index=False)
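
    # Note: on a standard Hugging Face Space the local CSVs live on ephemeral
    # storage and disappear on restart; the push_to_hub call above is what
    # actually persists the leaderboard.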

    def get_model_scores_df(self, is_causal):
        """Convert the model scores to a DataFrame."""
        reference_models = CAUSAL_LM_MODELS if is_causal else MASKED_LM_MODELS
        scores_data = []
        for model, stats in self.model_scores.items():
            if model not in reference_models:
                continue
            win_rate = (stats['wins'] / stats['total_comparisons'] * 100) if stats['total_comparisons'] > 0 else 0
            scores_data.append({
                'Model': model,
                'Wins': stats['wins'],
                'Total Comparisons': stats['total_comparisons'],
                'Win Rate (%)': round(win_rate, 2),
            })
        return pd.DataFrame(scores_data)


def create_battle_arena(dataset_path, is_gif):
    arena = LMBattleArena(dataset_path)

    def battle_round(is_causal):
        battle_data = arena.get_next_battle_pair(is_causal)
        if battle_data is None:
            return "All model pairs have been evaluated for all examples!", "", "", "", "", gr.DataFrame(visible=False)
        return (
            battle_data['prompt'],
            battle_data['model_1'],
            battle_data['model_2'],
            battle_data['model1_name'],
            battle_data['model2_name'],
            gr.DataFrame(visible=True),
        )

    def submit_preference(input_text, output_1, output_2, model1_name, model2_name, preferred_models, is_causal):
        scores_df = arena.record_evaluation(
            preferred_models, input_text, output_1, output_2, model1_name, model2_name, is_causal
        )
        next_battle = battle_round(is_causal)
        return (*next_battle[:-1], scores_df)

    with gr.Blocks(css="footer{display:none !important}") as demo:
        base_path = os.path.dirname(__file__)
        local_image_path = os.path.join(base_path, 'battle_leaderboard.gif')
        gr.HTML(create_html_media(local_image_path, is_gif=is_gif))
        with gr.Tabs():
            # The two arenas are identical apart from the tab name and the
            # is_causal flag, so build them in a loop instead of duplicating the layout
            for tab_name, causal_flag in [
                ("Masked LM Battle Arena", False),
                ("Causal LM Battle Arena", True),
            ]:
                with gr.Tab(tab_name):
                    gr.Markdown("# 🤖 Pretrained SmolLMs Battle Arena")
                    # gr.State stores the boolean without displaying it
                    is_causal = gr.State(value=causal_flag)
                    input_text = gr.Textbox(
                        label="Input prompt",
                        interactive=False,
                    )
                    with gr.Row():
                        output_1 = gr.Textbox(
                            label="Model A",
                            interactive=False,
                        )
                        model1_name = gr.State()  # hidden state for model 1's name
                    with gr.Row():
                        output_2 = gr.Textbox(
                            label="Model B",
                            interactive=False,
                        )
                        model2_name = gr.State()  # hidden state for model 2's name
                    preferred_models = gr.Radio(
                        label="Which model is better?",
                        choices=["Model A", "Model B", "Both Good", "Both Bad"],
                    )
                    submit_btn = gr.Button("Vote", variant="primary")
                    scores_table = gr.DataFrame(
                        headers=['Model', 'Wins', 'Total Comparisons', 'Win Rate (%)'],
                        label="🏆 Leaderboard",
                    )
                    submit_btn.click(
                        submit_preference,
                        inputs=[input_text, output_1, output_2, model1_name, model2_name, preferred_models, is_causal],
                        outputs=[input_text, output_1, output_2, model1_name, model2_name, scores_table],
                    )
                    demo.load(
                        battle_round,
                        inputs=[is_causal],
                        outputs=[input_text, output_1, output_2, model1_name, model2_name, scores_table],
                    )

    return demo


if __name__ == "__main__":
    dataset_path = 'human_eval_dataset.csv'
    is_gif = True

    # Load the existing dataset that contains the LM outputs and cache it locally as CSV
    load_dataset("atlasia/LM-Moroccan-Darija-Bench", split='test', token=TOKEN).to_csv(dataset_path)

    demo = create_battle_arena(dataset_path, is_gif)
    demo.launch(debug=True)
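
# Usage sketch (assumes `TOKEN` holds a Hugging Face token with read access to
# atlasia/LM-Moroccan-Darija-Bench and write access to the results dataset;
# `hf_xxx` is a placeholder):
#   TOKEN=hf_xxx python app.py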