import ast
import logging
import os
import re
import sqlite3
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import gradio as gr
import litellm
import pandas as pd
import plotly.express as px
from datasets import load_dataset
from plotly.graph_objects import Figure

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


@dataclass
class EvaluationResult:
    accuracy: float
    subject_accuracy: Dict[str, float]
    detailed_results: List[Dict]


class DatabaseManager:
    def __init__(self, db_path: str = 'afrimmlu_results.db'):
        self.db_path = db_path
        self._initialize_database()

    def _initialize_database(self) -> None:
        """Initialize SQLite database with required tables."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS summary_results (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        language TEXT NOT NULL,
                        subject TEXT NOT NULL,
                        accuracy REAL NOT NULL,
                        timestamp TEXT NOT NULL
                    )
                ''')
                cursor.execute('''
                    CREATE TABLE IF NOT EXISTS detailed_results (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        language TEXT NOT NULL,
                        timestamp TEXT NOT NULL,
                        subject TEXT NOT NULL,
                        question TEXT NOT NULL,
                        model_answer TEXT,
                        correct_answer TEXT NOT NULL,
                        is_correct INTEGER NOT NULL,
                        total_tokens INTEGER
                    )
                ''')
                conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Database initialization failed: {str(e)}")
            raise

    def save_results(self, language: str, summary_results: Dict[str, float],
                     detailed_results: List[Dict]) -> None:
        """Save evaluation results to database."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                timestamp = datetime.now().isoformat()

                # Save summary results
                cursor.executemany('''
                    INSERT INTO summary_results (language, subject, accuracy, timestamp)
                    VALUES (?, ?, ?, ?)
                ''', [(language, subject, accuracy, timestamp)
                      for subject, accuracy in summary_results.items()])

                # Save detailed results
                cursor.executemany('''
                    INSERT INTO detailed_results (
                        language, timestamp, subject, question, model_answer,
                        correct_answer, is_correct, total_tokens
                    )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', [(language, result['timestamp'], result['subject'],
                       result['question'], result['model_answer'],
                       result['correct_answer'], int(result['is_correct']),
                       result['total_tokens'])
                      for result in detailed_results])
                conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Failed to save results to database: {str(e)}")
            raise

    def query(self, query: str) -> pd.DataFrame:
        """Execute SQL query and return results as DataFrame."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                return pd.read_sql_query(query, conn)
        except sqlite3.Error as e:
            logger.error(f"Query execution failed: {str(e)}")
            return pd.DataFrame({'Error': [str(e)]})


class AfriMMLUEvaluator:
    def __init__(self, model_name: str = "deepseek/deepseek-chat"):
        self.model_name = model_name
        self.db_manager = DatabaseManager()

    def load_data(self, language_code: str = "swa") -> Optional[List[Dict]]:
        """Load AfriMMLU dataset for specified language."""
        try:
            dataset = load_dataset(
                'masakhane/afrimmlu',
                language_code,
                token=os.getenv('HF_TOKEN')
            )
            return dataset['test'].to_list()
        except Exception as e:
            logger.error(f"Failed to load dataset for {language_code}: {str(e)}")
            return None

    @staticmethod
    def preprocess_data(test_data: List[Dict]) -> List[Dict]:
        """Preprocess dataset to convert choices field to list."""
        preprocessed_data = []
        for example in test_data:
            try:
                if isinstance(example['choices'], str):
                    choices_str = example['choices'].strip("'\"").replace("\\'", "'")
                    example['choices'] = ast.literal_eval(choices_str)
                preprocessed_data.append(example)
            except (ValueError, SyntaxError):
                logger.warning(f"Skipping invalid choices: {example['choices']}")
                continue
        return preprocessed_data

    def evaluate(self, test_data: List[Dict], language: str) -> EvaluationResult:
        """Evaluate model on AfriMMLU dataset."""
        results = []
        correct = 0
        total = 0
        subject_results = defaultdict(lambda: {"correct": 0, "total": 0})

        for example in test_data:
            try:
                prompt = self._create_prompt(example)
                response = litellm.completion(
                    model=self.model_name,
                    messages=[{"role": "user", "content": prompt}]
                )
                model_answer = self._parse_model_answer(response.choices[0].message.content)

                is_correct = model_answer == example['answer'].upper()
                if is_correct:
                    correct += 1
                    subject_results[example['subject']]["correct"] += 1
                total += 1
                subject_results[example['subject']]["total"] += 1

                results.append({
                    'timestamp': datetime.now().isoformat(),
                    'subject': example['subject'],
                    'question': example['question'],
                    'model_answer': model_answer,
                    'correct_answer': example['answer'].upper(),
                    'is_correct': is_correct,
                    'total_tokens': response.usage.total_tokens
                })
            except Exception as e:
                logger.warning(f"Error processing question: {str(e)}")
                continue

        accuracy = (correct / total * 100) if total > 0 else 0
        subject_accuracy = {
            subject: (stats["correct"] / stats["total"] * 100) if stats["total"] > 0 else 0
            for subject, stats in subject_results.items()
        }

        self.db_manager.save_results(language, {**subject_accuracy, 'Overall': accuracy}, results)

        return EvaluationResult(
            accuracy=accuracy,
            subject_accuracy=subject_accuracy,
            detailed_results=results
        )

    @staticmethod
    def _create_prompt(example: Dict) -> str:
        """Create formatted prompt for model evaluation."""
        return (
            f"Answer the following multiple-choice question. "
            f"Return only the letter corresponding to the correct answer (A, B, C, or D).\n"
            f"Question: {example['question']}\n"
            f"Options:\n"
            f"A. {example['choices'][0]}\n"
            f"B. {example['choices'][1]}\n"
            f"C. {example['choices'][2]}\n"
            f"D. {example['choices'][3]}\n"
            f"Answer:"
        )

    @staticmethod
    def _parse_model_answer(output: str) -> Optional[str]:
        """Parse model output to extract the answer letter."""
        # Match a standalone A-D rather than scanning character by character,
        # which would misfire on the 'A' in words such as "ANSWER".
        match = re.search(r'\b([ABCD])\b', output.strip().upper())
        return match.group(1) if match else None


class VisualizationManager:
    @staticmethod
    def create_visualization(results: EvaluationResult) -> Tuple[pd.DataFrame, Figure]:
        """Create visualization from evaluation results."""
        summary_data = [
            {'Subject': subject, 'Accuracy (%)': accuracy}
            for subject, accuracy in results.subject_accuracy.items()
        ]
        summary_data.append({'Subject': 'Overall', 'Accuracy (%)': results.accuracy})
        summary_df = pd.DataFrame(summary_data)

        fig = px.bar(
            summary_df,
            x='Subject',
            y='Accuracy (%)',
            title='AfriMMLU Evaluation Results',
            labels={'Subject': 'Subject', 'Accuracy (%)': 'Accuracy (%)'},
            template='plotly_white'
        )
        fig.update_layout(
            xaxis_tickangle=-45,
            showlegend=False,
            height=600,
            margin=dict(b=200)
        )
        return summary_df, fig


def create_gradio_interface() -> gr.Blocks:
    """Create Gradio interface for AfriMMLU evaluation."""
    evaluator = AfriMMLUEvaluator()
    vis_manager = VisualizationManager()

    language_options = {
        "swa": "Swahili",
        "yor": "Yoruba",
        "wol": "Wolof",
        "lin": "Lingala",
        "ewe": "Ewe",
        "ibo": "Igbo"
    }

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AfriMMLU Evaluation Dashboard")

        with gr.Tabs():
            with gr.Tab("Model Evaluation"):
                with gr.Row():
                    language_input = gr.Dropdown(
                        choices=list(language_options.keys()),
                        label="Select Language",
                        value="swa"
                    )
                    model_input = gr.Dropdown(
                        choices=["deepseek/deepseek-chat"],
                        label="Select Model",
                        value="deepseek/deepseek-chat"
                    )
                evaluate_btn = gr.Button("Evaluate", variant="primary")
                summary_table = gr.Dataframe(label="Summary Results")
                summary_plot = gr.Plot(label="Performance by Subject")
                detailed_results = gr.Dataframe(label="Detailed Results", wrap=True)

            with gr.Tab("Database Analysis"):
                example_queries = gr.Dropdown(
                    choices=[
                        "SELECT language, AVG(accuracy) as avg_accuracy FROM summary_results WHERE subject='Overall' GROUP BY language",
                        "SELECT subject, AVG(accuracy) as avg_accuracy FROM summary_results GROUP BY subject",
                        "SELECT language, subject, accuracy, timestamp FROM summary_results ORDER BY timestamp DESC LIMIT 10",
                        "SELECT language, COUNT(*) as total_questions, SUM(is_correct) as correct_answers FROM detailed_results GROUP BY language",
                        "SELECT subject, COUNT(*) as total_evaluations FROM summary_results GROUP BY subject"
                    ],
                    label="Example Queries"
                )
                query_input = gr.Textbox(
                    label="SQL Query",
                    placeholder="Enter your SQL query here",
                    lines=3
                )
                query_button = gr.Button("Run Query", variant="primary")
                query_output = gr.Dataframe(label="Query Results", wrap=True)
                gr.Markdown("""
                ### Available Tables:
                - summary_results (id, language, subject, accuracy, timestamp)
                - detailed_results (id, language, timestamp, subject, question, model_answer, correct_answer, is_correct, total_tokens)
                """)

        def evaluate_language(language_code: str, model_name: str):
            evaluator.model_name = model_name
            test_data = evaluator.load_data(language_code)
            if not test_data:
                return None, None, None

            preprocessed_data = evaluator.preprocess_data(test_data)
            results = evaluator.evaluate(preprocessed_data, language_code)
            summary_df, plot = vis_manager.create_visualization(results)
            detailed_df = pd.DataFrame(results.detailed_results)
            return summary_df, plot, detailed_df

        evaluate_btn.click(
            fn=evaluate_language,
            inputs=[language_input, model_input],
            outputs=[summary_table, summary_plot, detailed_results]
        )
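        # Copy the selected example query into the SQL textbox so it can be edited.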
        example_queries.change(
            fn=lambda x: x,
            inputs=[example_queries],
            outputs=[query_input]
        )
        query_button.click(
            fn=evaluator.db_manager.query,
            inputs=[query_input],
            outputs=[query_output]
        )

    return demo


if __name__ == "__main__":
    try:
        # Validate environment variables
        required_env_vars = ['DEEPSEEK_API_KEY', 'HF_TOKEN']
        for var in required_env_vars:
            if not os.getenv(var):
                raise EnvironmentError(f"Missing required environment variable: {var}")

        demo = create_gradio_interface()
        demo.launch(share=True)
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise
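
# Minimal headless-usage sketch (not part of the app): assumes HF_TOKEN and
# DEEPSEEK_API_KEY are set in the environment, and uses "swa" (Swahili) as an
# example config. Uncomment to run a one-off evaluation and inspect the stored
# results without launching the Gradio UI.
#
#   evaluator = AfriMMLUEvaluator()
#   raw = evaluator.load_data("swa")
#   if raw:
#       result = evaluator.evaluate(evaluator.preprocess_data(raw), "swa")
#       print(f"Overall accuracy: {result.accuracy:.1f}%")
#       print(evaluator.db_manager.query(
#           "SELECT subject, accuracy FROM summary_results ORDER BY timestamp DESC LIMIT 10"
#       ))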