import json
import os
import random

from datasets.utils.logging import disable_progress_bar
from tabulate import tabulate  # renders solution grids as GitHub-style markdown tables

from constants import column_names, RANKING_COLUMN, ORDERED_COLUMN_NAMES
from utils_display import make_clickable_model
from eval_utils import *  # provides extract_last_complete_json and eval_each_puzzle

disable_progress_bar()

summary_file = "ZeroEval-main/result_dirs/zebra-grid.summary.json"
result_dir = "ZeroEval-main/result_dirs/zebra-grid/"
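# In-memory cache of each model's prediction file, keyed by model name.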
results_by_model = {}

# Format a cell value for display: strings pass through, numbers are rounded to one decimal.
def formatter(x):
    if isinstance(x, str):
        return x
    return round(x, 1)
 

def post_processing(df, column_names, rank_column=RANKING_COLUMN, ordered_columns=ORDERED_COLUMN_NAMES, click_url=True):
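    """Prepare a results dataframe for display: linkify model names, format
    numeric cells, rename and reorder columns, and sort by the ranking column."""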
    for col in df.columns:
        if col == "Model" and click_url:
            df[col] = df[col].apply(make_clickable_model)
        else:
            df[col] = df[col].apply(formatter) # For numerical values  
 
    df.rename(columns=column_names, inplace=True)
    list_columns = [col for col in ordered_columns if col in df.columns]
    df = df[list_columns]
    if rank_column in df.columns:
        df.sort_values(by=rank_column, inplace=True, ascending=False)
    return df
  

def load_all_data():
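    """Download any missing per-model prediction files from the ZeroEval repo
    and load them all into the results_by_model cache."""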
    with open(summary_file, "r") as f:
        model_summary = json.load(f)
    model_names = [model["Model"] for model in model_summary]
    # mkdir -p result_dir if not exists
    os.makedirs(result_dir, exist_ok=True)
    for model_name in model_names:
        download_url = f"https://raw.githubusercontent.com/yuchenlin/ZeroEval/main/result_dirs/zebra-grid/{model_name}.json"
        output_file = os.path.join(result_dir, f"{model_name}.json")
        if not os.path.exists(output_file):
            exit_code = os.system(f"wget {download_url} -O {output_file}")
            if exit_code != 0:
                # wget -O leaves an empty file behind on failure; remove it and skip this model.
                print(f"Failed to download {download_url}")
                if os.path.exists(output_file):
                    os.remove(output_file)
                continue
            print(f"Downloaded {model_name}.json")
        with open(output_file, "r") as f:
            print(f"Loading {output_file}")
            results_by_model[model_name] = json.load(f)
    
def get_random_item(model_name="random", size_H="random", size_W="random"):
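    """Pick one puzzle for the given model and grid size ("random" picks one at
    random) whose prediction parses into a complete solution table, and return
    it with markdown renderings of the predicted and ground-truth tables."""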
    if not results_by_model:
        load_all_data()
    if model_name == "random":
        model_name = random.choice(list(results_by_model.keys()))
    data = results_by_model[model_name]
    selected_item = None
    prediction_table = None  
    prediction_reasoning = None 
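    # Index items by id so candidates can be filtered by puzzle size below.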
    id_to_item = {}
    for item in data:
        id_to_item[item["id"]] = item
    
    if size_H == "random":
        size_H_choice =  random.choice(list(range(2, 7)))
    else:
        size_H_choice = size_H
    if size_W == "random":
        size_W_choice =  random.choice(list(range(2, 7)))
    else:
        size_W_choice = size_W
    ok_ids = [id for id in id_to_item if id_to_item[id]["size"].startswith(f"{size_H_choice}*{size_W_choice}")] 
    for ok_id in ok_ids:
        item = id_to_item[ok_id] 
        prediction_str = item["output"][0]
        prediction_json = extract_last_complete_json(prediction_str)
        if prediction_json is None or "solution" not in prediction_json:  
            continue 
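        # Exclude puzzles that mention a few hand-picked words/phrases.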
        if "child" in item["puzzle"].lower() or "mother" in item["puzzle"].lower():
            continue
        if "loves the spaghetti eater" in item["puzzle"].lower():
            continue 
        prediction_reasoning = prediction_json.get("reasoning", "")
        prediction_table = prediction_json["solution"]
        if prediction_table is not None and "House 1" in prediction_table:
            selected_item = item
            break 

    if selected_item is None:
        # selected_item = random.choice(data)
        print("No item found!")
        return None 

    explore_item = {}
    explore_item["id"] = selected_item["id"]
    explore_item["Model"] = model_name
    explore_item["size"] = selected_item["size"]
    explore_item["puzzle"] = selected_item["puzzle"]
    explore_item["solution"] = prediction_table
    explore_item["reasoning"] = prediction_reasoning

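    # Render the predicted solution as a GitHub-style markdown table, one row per house.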
    headers = ["Houses"] + list(prediction_table["House 1"].keys())
    rows = []
    for row_id in range(len(prediction_table)):
        row = [row_id+1] 
        for feature in headers[1:]:
            # Guard against houses or features missing from the model's prediction.
            row.append(prediction_table.get(f"House {row_id+1}", {}).get(feature, ""))
        rows.append(row)
    table_md = tabulate(rows, headers=headers, tablefmt="github")
    explore_item["solution_table_md"] = table_md

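    # Score the prediction cell-by-cell against the ground-truth solution.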
    this_total_cells, this_correct_cells, truth_solution_table = eval_each_puzzle(explore_item["id"], prediction_table)
    # print(table_md)
    explore_item["correct_cells"] = this_correct_cells
    explore_item["total_cells"] = this_total_cells
    explore_item["truth_solution_table"]  = tabulate(truth_solution_table["rows"], headers=truth_solution_table["header"], tablefmt="github")
    return explore_item


if __name__ == "__main__":
    load_all_data()
    print("All data downloaded!")
    print(json.dumps(get_random_item(model_name="gemini-1.5-pro", size_H="2", size_W="5"), indent=2))