import os
import gradio as gr
import pandas as pd
import plotly.express as px
from apscheduler.schedulers.background import BackgroundScheduler

from src.assets.text_content import TITLE, INTRODUCTION_TEXT, SINGLE_A100_TEXT, CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT
from src.utils import restart_space, load_dataset_repo, make_clickable_model, make_clickable_score, submit_query
from src.assets.css_html_js import custom_css


# Hub repo ids: the first is handed to restart_space, the second to
# load_dataset_repo (both imported from src.utils).
LLM_PERF_LEADERBOARD_REPO = "optimum/llm-perf-leaderboard"
LLM_PERF_DATASET_REPO = "optimum/llm-perf-dataset"
# Token read from the environment; None when the variable is not set.
OPTIMUM_TOKEN = os.environ.get("OPTIMUM_TOKEN")

# One entry per leaderboard column, in display order:
# (raw report column, displayed header, Gradio Dataframe datatype).
_LEADERBOARD_COLUMNS = (
    ("model", "Model 🤗", "markdown"),
    ("backend.name", "Backend 🏭", "str"),
    ("backend.torch_dtype", "Datatype 📥", "str"),
    ("forward.peak_memory(MB)", "Peak Memory (MB) ⬇️", "number"),
    ("generate.throughput(tokens/s)", "Throughput (tokens/s) ⬆️", "number"),
    ("h4_score", "Average H4 Score ⬆️", "markdown"),
)

# raw column name -> displayed header
COLUMNS_MAPPING = {raw: header for raw, header, _ in _LEADERBOARD_COLUMNS}
# datatypes, positionally aligned with COLUMNS_MAPPING
COLUMNS_DATATYPES = [dtype for _, _, dtype in _LEADERBOARD_COLUMNS]
# displayed header(s) the leaderboard is sorted by
SORTING_COLUMN = [COLUMNS_MAPPING["generate.throughput(tokens/s)"]]


# Handle on the benchmark dataset repository, obtained once at import time.
# The loader functions below test it for truthiness before calling
# git_pull(), so load_dataset_repo presumably returns a falsy value when no
# clone is available — TODO confirm against src.utils.
llm_perf_dataset_repo = load_dataset_repo(LLM_PERF_DATASET_REPO, OPTIMUM_TOKEN)


def get_benchmark_df(benchmark):
    """Build the leaderboard table for one benchmark report.

    Reads ``./llm-perf-dataset/reports/{benchmark}.csv`` and
    ``./llm-perf-dataset/reports/additional_data.csv``, left-joins them on
    the ``model`` column, and prepares the result for display.

    Args:
        benchmark: Report name, used as the CSV file stem
            (e.g. ``"1xA100-80GB"``).

    Returns:
        A new ``pandas.DataFrame`` holding only the columns listed in
        ``COLUMNS_MAPPING`` (renamed to their display headers) and sorted
        by ``SORTING_COLUMN`` in descending order.
    """
    # Refresh the local checkout first when a repo handle is available.
    if llm_perf_dataset_repo:
        llm_perf_dataset_repo.git_pull()

    # load
    reports_dir = "./llm-perf-dataset/reports"
    bench_df = pd.read_csv(f"{reports_dir}/{benchmark}.csv")
    scores_df = pd.read_csv(f"{reports_dir}/additional_data.csv")
    merged_df = bench_df.merge(scores_df, on="model", how="left")

    # preprocess: both columns are declared "markdown" in COLUMNS_DATATYPES
    merged_df["model"] = merged_df["model"].apply(make_clickable_model)
    merged_df["h4_score"] = merged_df["h4_score"].apply(make_clickable_score)

    # filter -> rename -> sort, each step returning a fresh frame rather
    # than mutating a column selection in place (the in-place form is what
    # pandas' SettingWithCopyWarning is about).
    return (
        merged_df[list(COLUMNS_MAPPING)]
        .rename(columns=COLUMNS_MAPPING)
        .sort_values(by=SORTING_COLUMN, ascending=False)
    )


# Dataframes
# Built once at import time; the table shown in the UI is this snapshot.
single_A100_df = get_benchmark_df(benchmark="1xA100-80GB")


def get_benchmark_plot(benchmark):
    """Create the score/latency/memory scatter plot for one benchmark report.

    Reads ``./llm-perf-dataset/reports/{benchmark}.csv`` and
    ``./llm-perf-dataset/reports/additional_data.csv``, left-joins them on
    ``model``, keeps rows whose ``generate.latency(s)`` is below 100, and
    plots ``h4_score`` (x) against ``generate.latency(s)`` (y). Marker color
    follows ``model_type``, marker symbol follows ``backend.name`` and
    marker size follows ``forward.peak_memory(MB)``.

    Args:
        benchmark: Report name, used as the CSV file stem.

    Returns:
        The configured plotly figure.
    """
    # Refresh the local checkout first when a repo handle is available.
    if llm_perf_dataset_repo:
        llm_perf_dataset_repo.git_pull()

    # load
    report_df = pd.read_csv(f"./llm-perf-dataset/reports/{benchmark}.csv")
    extra_df = pd.read_csv("./llm-perf-dataset/reports/additional_data.csv")
    merged_df = report_df.merge(extra_df, on="model", how="left")

    # keep only the rows with a latency below 100
    below_cutoff = merged_df["generate.latency(s)"] < 100
    plot_df = merged_df[below_cutoff]

    # columns exposed to the hover template, addressed there by position
    hover_columns = [
        "model",
        "backend.name",
        "backend.torch_dtype",
        "forward.peak_memory(MB)",
        "generate.throughput(tokens/s)",
    ]
    figure = px.scatter(
        plot_df,
        x="h4_score",
        y="generate.latency(s)",
        color="model_type",
        symbol="backend.name",
        size="forward.peak_memory(MB)",
        custom_data=hover_columns,
    )

    centered_title = dict(
        text="Model Score vs. Latency vs. Memory",
        y=0.95,
        x=0.5,
        xanchor="center",
        yanchor="top",
    )
    figure.update_layout(
        title=centered_title,
        xaxis_title="Average H4 Score",
        yaxis_title="Latency per 1000 Tokens (s)",
        legend_title="Model Type, Backend",
        width=1200,
        height=600,
    )

    # one hover line per entry; customdata indices follow hover_columns
    hover_lines = (
        "Model: %{customdata[0]}",
        "Backend: %{customdata[1]}",
        "Datatype: %{customdata[2]}",
        "Peak Memory (MB): %{customdata[3]}",
        "Throughput (tokens/s): %{customdata[4]}",
        "Latency per 1000 Tokens (s): %{y}",
        "Average H4 Score: %{x}",
    )
    figure.update_traces(hovertemplate="<br>".join(hover_lines))

    return figure


# Plots
# Built once at import time; the figure shown in the UI is this snapshot.
single_A100_plot = get_benchmark_plot(benchmark="1xA100-80GB")

# Demo interface
# Layout, top to bottom: title, introduction, control panel (search box,
# backend/datatype checkboxes, score slider, submit button), two tabs
# (leaderboard table and plot), and a collapsed citation box.
demo = gr.Blocks(css=custom_css)
with demo:
    # leaderboard title
    gr.HTML(TITLE)

    # introduction text
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    # control panel title
    gr.HTML("<h2>Control Panel 🎛️</h2>")

    # control panel interface
    with gr.Row():
        search_bar = gr.Textbox(
            label="Model 🤗",
            info="🔍 Search for a model name",
            elem_id="search-bar",
        )
        backend_checkboxes = gr.CheckboxGroup(
            label="Backends 🏭",
            choices=["pytorch", "onnxruntime"],
            value=["pytorch", "onnxruntime"],
            info="☑️ Select the backends",
            elem_id="backend-checkboxes",
        )
        datatype_checkboxes = gr.CheckboxGroup(
            label="Datatypes 📥",
            choices=["float32", "float16"],
            value=["float32", "float16"],
            info="☑️ Select the load datatypes",
            elem_id="datatype-checkboxes",
        )
        threshold_slider = gr.Slider(
            label="Average H4 Score 📈",
            # Hint text was truncated ("lter by ..."); restored so it reads
            # like the hints on the other controls.
            info="🎚️ Filter by minimum average H4 score",
            value=0.0,
            elem_id="threshold-slider",
        )

    with gr.Row():
        submit_button = gr.Button(
            value="Submit 🚀",
            elem_id="submit-button",
        )

    # leaderboard tabs
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🖥️ A100-80GB Leaderboard 🏆", id=0):
            gr.HTML(SINGLE_A100_TEXT)

            # Original leaderboard table
            single_A100_leaderboard = gr.components.Dataframe(
                value=single_A100_df,
                datatype=COLUMNS_DATATYPES,
                headers=list(COLUMNS_MAPPING.values()),
                elem_id="1xA100-table",
            )
            # Hidden copy of the full table: it is the input submit_query
            # receives, so each query starts from the unfiltered data
            # instead of the already-filtered visible table.
            single_A100_for_search = gr.components.Dataframe(
                value=single_A100_df,
                datatype=COLUMNS_DATATYPES,
                headers=list(COLUMNS_MAPPING.values()),
                max_rows=None,
                visible=False,
            )

        # On submit: pass the control values plus the hidden full table to
        # submit_query and write its result into the visible leaderboard.
        submit_button.click(
            submit_query,
            [
                search_bar, backend_checkboxes, datatype_checkboxes, threshold_slider,
                single_A100_for_search
            ],
            [single_A100_leaderboard]
        )

        with gr.TabItem("🖥️ A100-80GB Plot 📊", id=1):
            # Same hardware description as on the leaderboard tab
            gr.HTML(SINGLE_A100_TEXT)

            # Original leaderboard plot
            single_A100_plotly = gr.components.Plot(
                value=single_A100_plot,
                elem_id="1xA100-plot",
                show_label=False,
            )

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            # NOTE(review): .style(...) and max_rows above look like
            # Gradio 3.x-era APIs — confirm the pinned Gradio version
            # before upgrading.
            citation_button = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                label=CITATION_BUTTON_LABEL,
                elem_id="citation-button",
            ).style(show_copy_button=True)


# Restart space every hour (3600 s interval job on a background scheduler)
scheduler = BackgroundScheduler()
scheduler.add_job(
    restart_space,
    trigger="interval",
    args=[LLM_PERF_LEADERBOARD_REPO, OPTIMUM_TOKEN],
    seconds=3600,
)
scheduler.start()

# Launch demo: enable the request queue, then start serving
demo.queue(concurrency_count=40)
demo.launch()