import sys

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# fmt: off
# Emoji shown in the leaderboard "Type" column for each "Model Type" value.
type_emoji = dict(
    [
        ("RTL-Specific", "🔴"),
        ("General", "🟢"),
        ("Coding", "🔵"),
    ]
)
# fmt: on


def model_hyperlink(link, model_name, release):
    """Render ``model_name`` as an HTML link to ``link`` that opens in a new tab.

    Any ``release`` other than ``"V1"`` gets a small-caps "new" badge appended
    after the link.

    NOTE(review): ``link`` and ``model_name`` are interpolated without HTML
    escaping, which is only safe for trusted leaderboard data.
    """
    anchor = (
        f'<a target="_blank" href="{link}" '
        'style="color: var(--link-text-color); '
        'text-decoration: underline;text-decoration-style: dotted;">'
        f"{model_name}</a>"
    )
    if release != "V1":
        anchor += ' <span style="font-variant: all-small-caps; font-weight: 600">new</span>'
    return anchor


def handle_special_cases(benchmark, metric):
    """Keep the "RTL-Repo" benchmark paired with the "Exact Matching (EM)" metric.

    If the metric is "Exact Matching (EM)", the benchmark is forced to
    "RTL-Repo". Otherwise, if the benchmark is "RTL-Repo", the metric is
    forced to "Exact Matching (EM)". The metric check wins when both apply.
    Any other combination is returned untouched.

    Returns:
        The possibly adjusted ``(benchmark, metric)`` tuple.
    """
    if metric == "Exact Matching (EM)":
        return "RTL-Repo", metric
    if benchmark == "RTL-Repo":
        return benchmark, "Exact Matching (EM)"
    return benchmark, metric


def filter_RTLRepo(subset: pd.DataFrame) -> pd.DataFrame:
    """Build the RTL-Repo leaderboard table from long-format results.

    Rows with a negative ``Score`` are discarded. NOTE(review): presumably a
    negative score marks a missing/invalid result; confirm. The remaining
    scores are exposed as "Exact Matching (EM)". Each model name is rendered
    as a hyperlink and a type emoji is attached.

    Args:
        subset: Long-format results (presumably already restricted to
            RTL-Repo by the caller) with columns ``Model``, ``Score``,
            ``Model URL``, ``Model Type``, ``Params`` and ``Release``.

    Returns:
        A new frame with columns ``Type``, ``Model``, ``Params`` and
        ``Exact Matching (EM)``, sorted by EM descending, with a fresh
        0..n-1 index.
    """
    # Filter with a boolean mask rather than ``drop(<index labels>)``.
    # Dropping by label removes *every* row carrying a matching label, so with
    # a non-unique index valid rows would be discarded too. ``~(score < 0)``
    # keeps NaN scores, exactly like the label-based drop did.
    subset = subset[~(subset["Score"] < 0.0)]

    # One row of descriptive metadata per model.
    details = subset[
        ["Model", "Model URL", "Model Type", "Params", "Release"]
    ].drop_duplicates("Model")
    filtered_df = subset[["Model", "Score"]].rename(
        columns={"Score": "Exact Matching (EM)"}
    )
    filtered_df = pd.merge(filtered_df, details, on="Model", how="left")
    filtered_df["Model"] = filtered_df.apply(
        lambda row: model_hyperlink(row["Model URL"], row["Model"], row["Release"]),
        axis=1,
    )
    filtered_df["Type"] = filtered_df["Model Type"].map(lambda x: type_emoji.get(x, ""))
    filtered_df = filtered_df[["Type", "Model", "Params", "Exact Matching (EM)"]]
    return filtered_df.sort_values(
        by="Exact Matching (EM)", ascending=False
    ).reset_index(drop=True)


def filter_bench(subset: pd.DataFrame, df_agg=None, agg_column=None) -> pd.DataFrame:
    """Pivot long-format benchmark results into the per-model leaderboard table.

    Produces one row per model and one column per metric. When a model has
    several scores for the same metric, they are merged by ``custom_agg_s2r``
    if any row belongs to the "RTLLM" benchmark, otherwise by
    ``custom_agg_cc``. Model names become hyperlinks, a type emoji is added,
    and "Post-Synthesis" is the mean of the "Power", "Performance" and "Area"
    metrics (only when all three are present).

    NOTE(review): the ``custom_agg_*`` functions weight scores by position, so
    the relative row order of the two benchmarks in ``subset`` matters.

    Args:
        subset: Long-format results with columns ``Model``, ``Model URL``,
            ``Model Type``, ``Params``, ``Release``, ``Benchmark``, ``Metric``
            and ``Score``.
        df_agg: Unused. Kept for call compatibility; merging a precomputed
            aggregate column is currently disabled.
        agg_column: Unused; see ``df_agg``.

    Returns:
        A new frame with whichever of ``Type``, ``Model``,
        ``Parameters (B)``, ``Syntax``, ``Functionality``, ``Synthesis`` and
        ``Post-Synthesis`` are available. It is sorted by ``Functionality``
        (descending) when that column exists and has a fresh 0..n-1 index.
    """
    # One row of descriptive metadata per model.
    details = subset[
        ["Model", "Model URL", "Model Type", "Params", "Release"]
    ].drop_duplicates("Model")

    # The two benchmark groups differ only in how duplicate cells are merged.
    aggfunc = (
        custom_agg_s2r if "RTLLM" in subset["Benchmark"].unique() else custom_agg_cc
    )
    pivot_df = (
        subset.pivot_table(
            index="Model", columns="Metric", values="Score", aggfunc=aggfunc
        )
        .reset_index()
        .round(2)
    )

    pivot_df = pd.merge(pivot_df, details, on="Model", how="left")
    pivot_df["Model"] = pivot_df.apply(
        lambda row: model_hyperlink(row["Model URL"], row["Model"], row["Release"]),
        axis=1,
    )
    pivot_df["Type"] = pivot_df["Model Type"].map(lambda x: type_emoji.get(x, ""))

    # Compute post-synthesis quality only when all PPA metrics exist, so that
    # subsets without them still render (the column is optional below).
    ppa_metrics = ["Power", "Performance", "Area"]
    if all(metric in pivot_df.columns for metric in ppa_metrics):
        pivot_df["Post-Synthesis (PSQ)"] = (
            pivot_df[ppa_metrics].mean(axis=1).round(2)
        )

    pivot_df = pivot_df.rename(
        columns={
            "Params": "Parameters (B)",
            "Syntax (STX)": "Syntax",
            "Functionality (FNC)": "Functionality",
            "Synthesis (SYN)": "Synthesis",
            "Post-Synthesis (PSQ)": "Post-Synthesis",
        }
    )
    columns_order = [
        "Type",
        "Model",
        "Parameters (B)",
        "Syntax",
        "Functionality",
        "Synthesis",
        "Post-Synthesis",
    ]
    pivot_df = pivot_df[[col for col in columns_order if col in pivot_df.columns]]
    if "Functionality" in pivot_df.columns:
        pivot_df = pivot_df.sort_values(by="Functionality", ascending=False)
    return pivot_df.reset_index(drop=True)


def custom_agg_s2r(vals, w1=155, w2=47):
    """Merge the scores of one pivot cell into a single value.

    Used as the ``pivot_table`` ``aggfunc`` when the subset contains RTLLM
    results. With exactly two values, returns the weighted mean
    ``(w1 * vals[0] + w2 * vals[1]) / (w1 + w2)``. With any other count,
    returns the first value. The result is rounded to 2 decimals.

    NOTE(review): the first value is treated as the S2R score and the second
    as the RTLLM score. The defaults presumably reflect the two benchmarks'
    problem counts; confirm. Weights are applied by position, so rows must
    reach ``pivot_table`` in that order.

    Args:
        vals: Scores for one model/metric cell, indexable with ``.iloc``.
        w1: Weight of the first (S2R) score.
        w2: Weight of the second (RTLLM) score.

    Returns:
        The merged score rounded to 2 decimals.
    """
    if len(vals) != 2:
        return round(vals.iloc[0], 2)
    s2r_val = vals.iloc[0]
    rtllm_val = vals.iloc[1]
    return round((w1 * s2r_val + w2 * rtllm_val) / (w1 + w2), 2)


def custom_agg_cc(vals, w1=155, w2=17):
    """Merge the scores of one pivot cell into a single value.

    Used as the ``pivot_table`` ``aggfunc`` when the subset does not contain
    RTLLM results. With exactly two values, returns the weighted mean
    ``(w1 * vals[0] + w2 * vals[1]) / (w1 + w2)``. With any other count,
    returns the first value. The result is rounded to 2 decimals.

    NOTE(review): the first value is treated as the VerilogEval score and the
    second as the VGen score. The defaults presumably reflect the two
    benchmarks' problem counts; confirm. Weights are applied by position, so
    rows must reach ``pivot_table`` in that order.

    Args:
        vals: Scores for one model/metric cell, indexable with ``.iloc``.
        w1: Weight of the first (VerilogEval) score.
        w2: Weight of the second (VGen) score.

    Returns:
        The merged score rounded to 2 decimals.
    """
    if len(vals) != 2:
        return round(vals.iloc[0], 2)
    veval_val = vals.iloc[0]
    vgen_val = vals.iloc[1]
    return round((w1 * veval_val + w2 * vgen_val) / (w1 + w2), 2)


def filter_bench_all(
    subset: pd.DataFrame, df_agg=None, agg_column=None
) -> pd.DataFrame:
    """Build the leaderboard table for the combined ("all benchmarks") view.

    The table is identical to the one produced by ``filter_bench``: same
    pivot and aggregation, hyperlinks, type emoji, Post-Synthesis mean,
    column selection and sort. This function therefore delegates to it
    instead of duplicating the pipeline.

    NOTE(review): an "Exact Matching (EM)" metric is not among the selected
    output columns, so it does not appear in this view. If the combined view
    should show EM, extend the column list in ``filter_bench``.

    Args:
        subset: Long-format results with columns ``Model``, ``Model URL``,
            ``Model Type``, ``Params``, ``Release``, ``Benchmark``, ``Metric``
            and ``Score``.
        df_agg: Forwarded to ``filter_bench`` (currently unused there).
        agg_column: Forwarded to ``filter_bench`` (currently unused there).

    Returns:
        The leaderboard frame built by ``filter_bench``.
    """
    return filter_bench(subset, df_agg=df_agg, agg_column=agg_column)