import gradio as gr
import pandas as pd
import plotly.express as px


from src.utils import process_arch


# Columns attached to every plotted point as plotly `custom_data`; the hover
# template in the figure builders lists them one per line, in this order.
# Each name must appear only once, otherwise it is shown twice in the tooltip.
FLASHATTENTIONV2_DATA = [
    # open llm
    "Model 🤗",
    "Arch 🏛️",
    "Params (B)",
    "Open LLM Score (%)",
    # deployment settings
    "DType 📥",
    "Backend 🏭",
    "Quantization 🗜️",
    # primary measurements (baseline run, then the FlashAttentionV2 run)
    "Prefill Latency (s)",
    "Prefill Latency (s) FlashAttentionV2",
    "Decode Throughput (tokens/s)",
    "Decode Throughput (tokens/s) FlashAttentionV2",
    "E2E Throughput (tokens/s)",
    "E2E Throughput (tokens/s) FlashAttentionV2",
    # speedups
    "Prefill Latency Speedup (%)",
    "Decode Throughput Speedup (%)",
]


def get_fa2_df(llm_perf_df):
    """Pair baseline and FlashAttentionV2 runs of each model and compute speedups.

    Rows whose "Optimization 🛠️" value is the string "None" are the baseline;
    rows whose value is "FlashAttentionV2" are the optimized runs. The two sets
    are inner-joined on model and quantization. Baseline columns keep their
    names, and the optimized run's overlapping columns get the suffix
    " FlashAttentionV2".

    Args:
        llm_perf_df: benchmark results table. It is copied, not modified.

    Returns:
        The merged DataFrame with two extra columns, rounded to 2 decimals:
        "Prefill Latency Speedup (%)" (baseline latency / FA2 latency * 100)
        and "Decode Throughput Speedup (%)" (FA2 throughput / baseline
        throughput * 100). Only rows where both speedups are below 1000 are
        kept. Rows with NaN or infinite speedups are dropped too, because
        those comparisons evaluate to False.
    """
    perf_df = llm_perf_df.copy()
    perf_df["Arch 🏛️"] = perf_df["Arch 🏛️"].apply(process_arch)

    # Split into baseline and FlashAttentionV2 experiments.
    optimization = perf_df["Optimization 🛠️"]
    baseline_runs = perf_df[optimization == "None"]
    fa2_runs = perf_df[optimization == "FlashAttentionV2"]

    merged = baseline_runs.merge(
        fa2_runs,
        on=["Model 🤗", "Quantization 🗜️"],
        suffixes=("", " FlashAttentionV2"),
    )

    # Lower latency is better, so the prefill ratio is baseline / FA2.
    prefill_ratio = merged["Prefill Latency (s)"] / merged["Prefill Latency (s) FlashAttentionV2"]
    merged["Prefill Latency Speedup (%)"] = (prefill_ratio * 100).round(2)
    # Higher throughput is better, so the decode ratio is FA2 / baseline.
    decode_ratio = merged["Decode Throughput (tokens/s) FlashAttentionV2"] / merged["Decode Throughput (tokens/s)"]
    merged["Decode Throughput Speedup (%)"] = (decode_ratio * 100).round(2)

    # Drop outliers: keep only rows where both speedups stay below 1000%.
    plausible = (merged["Prefill Latency Speedup (%)"] < 1000) & (merged["Decode Throughput Speedup (%)"] < 1000)
    return merged[plausible]


def _make_fa2_speedup_fig(llm_perf_df, y_column, title, yaxis_title):
    """Build a per-architecture box plot of one FlashAttentionV2 speedup column.

    Shared by the decode and prefill figures, which differ only in the plotted
    column and the labels.

    Args:
        llm_perf_df: raw benchmark table, forwarded to ``get_fa2_df``.
        y_column: name of the speedup column to plot on the y axis.
        title: figure title text.
        yaxis_title: y-axis label.

    Returns:
        A plotly Figure. Points are coloured by "Quantization 🗜️", and every
        point carries the ``FLASHATTENTIONV2_DATA`` columns as custom data,
        which the hover template lists one per line.
    """
    fa2_df = get_fa2_df(llm_perf_df)
    fig = px.box(
        fa2_df,
        x="Arch 🏛️",
        y=y_column,
        color_discrete_sequence=px.colors.qualitative.Light24,
        custom_data=FLASHATTENTIONV2_DATA,
        color="Quantization 🗜️",
        points="all",
    )
    # customdata[i] lines up with FLASHATTENTIONV2_DATA[i], because the
    # template is built from the same list passed as custom_data above.
    fig.update_traces(
        hovertemplate="<br>".join(
            [f"<b>{column}:</b> %{{customdata[{i}]}}" for i, column in enumerate(FLASHATTENTIONV2_DATA)]
        )
    )
    fig.update_layout(
        title={
            "text": title,
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        },
        xaxis_title="LLM Architecture",
        yaxis_title=yaxis_title,
        legend_title="Quantization Scheme",
        width=1200,
        height=600,
    )
    return fig


def get_fa2_decode_fig(llm_perf_df):
    """Return the box plot of FlashAttentionV2 decode-throughput speedup per architecture."""
    return _make_fa2_speedup_fig(
        llm_perf_df,
        y_column="Decode Throughput Speedup (%)",
        title="Decode Throughput Speedup per Architecture",
        yaxis_title="Decode Speedup (%)",
    )


def get_fa2_prefill_fig(llm_perf_df):
    """Return the box plot of FlashAttentionV2 prefill-latency speedup per architecture."""
    return _make_fa2_speedup_fig(
        llm_perf_df,
        y_column="Prefill Latency Speedup (%)",
        title="Prefill Latency Speedup per Architecture",
        yaxis_title="Prefill Speedup (%)",
    )


def create_fa2_plots(llm_perf_df):
    """Create the FlashAttentionV2 hint text and the prefill/decode speedup plots.

    Components are instantiated in this order: the hint HTML, the prefill
    plot, then the decode plot. This is presumably called inside a
    ``gr.Blocks`` layout context (confirm with the caller).

    Args:
        llm_perf_df: raw benchmark table, forwarded to the figure builders.

    Returns:
        A ``(prefill_plot, decode_plot)`` tuple of ``gr.components.Plot``.
        Both plots use ``elem_id="plot"``.
    """
    gr.HTML("👆 Hover over the points 👆 for additional information.", elem_id="text")

    figures = (get_fa2_prefill_fig(llm_perf_df), get_fa2_decode_fig(llm_perf_df))
    prefill_plot, decode_plot = (
        gr.components.Plot(value=figure, elem_id="plot", show_label=False) for figure in figures
    )
    return prefill_plot, decode_plot