import streamlit as st
import os
import numpy as np
import pandas as pd
from logger import logger
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import json
from utils import fs,validate_email
from enums import SAVE_PATH, ELO_JSON_PATH, ELO_CSV_PATH, EMAIL_PATH


def write_email(email):
    """Append ``email`` plus a trailing newline to the file at ``EMAIL_PATH``.

    The existing file (if any) is read in full through ``fs``, decoded as
    UTF-8, extended in memory, and written back as UTF-8 bytes.

    NOTE(review): the read and the write are separate operations, so two
    concurrent submissions could each read the old content and one of the
    appended emails would be lost.
    """
    previous = ''
    if fs.exists(EMAIL_PATH):
        with fs.open(EMAIL_PATH, 'rb') as handle:
            previous = handle.read().decode('utf-8')

    updated = previous + email + '\n'

    with fs.open(EMAIL_PATH, 'wb') as handle:
        handle.write(updated.encode('utf-8'))

# Internal model key -> label shown in the UI. Keys not listed here are
# displayed unchanged.
_MODEL_DISPLAY_NAMES = {
    'Ori Apex': 'Ori Apex',
    'Ori Apex XT': 'Ori Apex XT',
    'deepgram': 'Deepgram',
    'Ori Swift': 'Ori Swift',
    'Ori Prime': 'Ori Prime',
    'azure': 'Azure',
    'sarvam': 'Sarvam',
}


def get_model_abbreviation(model_name):
    """Return the display label for ``model_name``.

    Unknown names are returned as-is.
    """
    if model_name in _MODEL_DISPLAY_NAMES:
        return _MODEL_DISPLAY_NAMES[model_name]
    return model_name


def calculate_metrics(df, models=None):
    """Compute per-model arena statistics from the match-log DataFrame.

    Args:
        df: Match log. For every model key ``m`` it must contain the columns
            ``f'{m}_appearance'`` (summed to get the appearance count; rows
            equal to 1 mark matches the model took part in -- presumably a
            0/1 flag), ``f'{m}_score'`` (summed as the model's wins) and
            ``f'{m}_duration'`` (averaged over the model's appearances).
        models: Iterable of model keys to report on. Defaults to the seven
            arena models this dashboard tracks.

    Returns:
        dict mapping each model key to a dict with ``appearances``, ``wins``,
        ``win_rate`` (percent of appearances that were wins),
        ``avg_response_time`` and ``response_time_std``. A model with no
        appearances gets 0 for the last three values.

    Raises:
        KeyError: if one of the required columns is missing from ``df``.
    """
    if models is None:
        # Default roster; keys must match the column prefixes used in ``df``.
        models = ['Ori Apex', 'Ori Apex XT', 'deepgram', 'Ori Swift', 'Ori Prime', 'azure', 'sarvam']

    metrics = {}

    for model in models:
        appearance_col = df[f'{model}_appearance']
        appearances = appearance_col.sum()
        wins = df[f'{model}_score'].sum()
        # Response times only from the matches the model actually appeared in.
        durations = df.loc[appearance_col == 1, f'{model}_duration']

        if appearances > 0:
            win_rate = (wins / appearances) * 100
            avg_duration = durations.mean()
            # pandas sample std (ddof=1): NaN when there is a single appearance.
            duration_std = durations.std()
        else:
            win_rate = 0
            avg_duration = 0
            duration_std = 0

        metrics[model] = {
            'appearances': appearances,
            'wins': wins,
            'win_rate': win_rate,
            'avg_response_time': avg_duration,
            'response_time_std': duration_std,
        }

    return metrics

def create_win_rate_chart(metrics):
    """Build a plotly bar chart of each model's win rate.

    Args:
        metrics: Mapping of model key -> stats dict containing ``win_rate``
            (as produced by ``calculate_metrics``).

    Returns:
        A ``go.Figure`` with one bar per model. Bars are labelled with the
        display name, annotated with the rate to one decimal place, and
        show the raw model key on hover. The y axis is fixed to 0-100.
    """
    model_keys = list(metrics)
    rates = [metrics[key]['win_rate'] for key in model_keys]
    labels = [get_model_abbreviation(key) for key in model_keys]
    bar_text = [f'{rate:.1f}%' for rate in rates]

    bar = go.Bar(
        x=labels,
        y=rates,
        text=bar_text,
        textposition='auto',
        hovertext=model_keys,
    )
    figure = go.Figure(data=[bar])

    figure.update_layout(
        title='Win Rate by Model',
        xaxis_title='Model',
        yaxis_title='Win Rate (%)',
        yaxis_range=[0, 100],
    )

    return figure

def create_appearance_chart(metrics):
    """Build a plotly-express pie chart of how often each model appeared.

    Args:
        metrics: Mapping of model key -> stats dict containing
            ``appearances`` (as produced by ``calculate_metrics``).

    Returns:
        A plotly figure with one slice per model, labelled by display name.
    """
    slice_values = []
    slice_labels = []
    for key, stats in metrics.items():
        slice_values.append(stats['appearances'])
        slice_labels.append(get_model_abbreviation(key))

    return px.pie(
        values=slice_values,
        names=slice_labels,
        title='Model Appearances Distribution',
    )

def create_head_to_head_matrix(df, models=None):
    """Build a heatmap of pairwise win rates between models.

    Cell (row ``i``, column ``j``) covers the matches in which both model
    ``i`` and model ``j`` appeared. Its value is the sum of model ``i``'s
    ``_score`` column over those matches, divided by their count, times 100.

    The diagonal and any pair with no shared matches are NaN. Plotly renders
    them as empty gaps, so they are not coloured or labelled as a 0% win
    rate. A genuine 0% win rate is still labelled "0.0%".

    Args:
        df: Match log with ``f'{m}_appearance'`` and ``f'{m}_score'``
            columns for every model key ``m`` in ``models``.
        models: Model keys to include (rows and columns, in order). Defaults
            to the seven arena models this dashboard tracks.

    Returns:
        A ``go.Figure`` containing the heatmap.
    """
    if models is None:
        models = ['Ori Apex', 'Ori Apex XT', 'deepgram', 'Ori Swift', 'Ori Prime', 'azure', 'sarvam']

    labels = [get_model_abbreviation(model) for model in models]
    # NaN = "no data". Zeros would be indistinguishable from a 0% win rate.
    matrix = np.full((len(models), len(models)), np.nan)

    for i, model1 in enumerate(models):
        model1_present = df[f'{model1}_appearance'] == 1  # invariant over j
        for j, model2 in enumerate(models):
            if i == j:
                continue
            matches = df[model1_present & (df[f'{model2}_appearance'] == 1)]
            if len(matches) > 0:
                matrix[i, j] = (matches[f'{model1}_score'].sum() / len(matches)) * 100

    cell_text = [['' if np.isnan(val) else f'{val:.1f}%' for val in row] for row in matrix]

    fig = go.Figure(data=go.Heatmap(
        z=matrix,
        x=labels,
        y=labels,
        text=cell_text,
        texttemplate='%{text}',
        colorscale='RdYlBu',
        zmin=0,
        zmax=100,
    ))

    fig.update_layout(
        title='Head-to-Head Win Rates',
        xaxis_title='Opponent Model',
        yaxis_title='Model',
    )

    return fig

def create_elo_chart(df):
    """Plot every column of ``df`` as a line-and-marker trace over row order.

    The x value of each point is its 0-based row position, and the x axis is
    titled "Match Number". Each trace is named after its column. Presumably
    each column is one model's Elo rating history (verify against the writer
    of ELO_CSV_PATH).

    Returns:
        A single-subplot plotly figure with a unified-x hover.
    """
    # row_heights are relative heights; with a single row this presumably
    # has no visible effect.
    fig = make_subplots(rows=1, cols=1, row_heights=[0.7])
    match_numbers = list(range(len(df)))

    for series_name in df.columns:
        trace = go.Scatter(
            x=match_numbers,
            y=df[series_name],
            name=series_name,
            mode='lines+markers',
        )
        fig.add_trace(trace, row=1, col=1)

    fig.update_layout(
        title='Model ELO Ratings Analysis',
        showlegend=True,
        hovermode='x unified',
    )
    fig.update_xaxes(title_text='Match Number', row=1, col=1)

    return fig

def create_metric_container(label, value, full_name=None):
    """Render a labelled metric inside a new Streamlit container.

    Shows ``label`` in bold and ``value`` as an ``<h3>`` heading. When
    ``full_name`` is truthy, a "Full name: ..." caption is added beneath.

    NOTE(review): ``value`` is interpolated into raw HTML with
    ``unsafe_allow_html=True``. The visible callers pass counts and model
    names; do not pass untrusted text here.
    """
    with st.container():
        st.markdown(f"**{label}**")
        st.markdown(f"<h3 style='margin-top: 0;'>{value}</h3>", unsafe_allow_html=True)
        if full_name:
            st.caption(f"Full name: {full_name}")

def on_refresh_click():
    """Reload all dashboard data into ``st.session_state`` (Refresh callback).

    The match log at SAVE_PATH is required, so any read error propagates.
    The Elo JSON (ELO_JSON_PATH) and Elo CSV (ELO_CSV_PATH) are optional:
    on any failure the error is logged and ``None`` is stored instead.
    """
    st.toast("Refreshing data... please wait",icon="🔄")

    with fs.open(SAVE_PATH, 'rb') as handle:
        st.session_state.df = pd.read_csv(handle)

    def _load_optional(path, mode, reader, error_message):
        # Best-effort read: log the failure and fall back to None.
        try:
            with fs.open(path, mode) as handle:
                return reader(handle)
        except Exception as exc:
            logger.error(error_message, exc)
            return None

    st.session_state.elo_json = _load_optional(
        ELO_JSON_PATH, 'r', json.load, "Error while reading elo json file %s"
    )
    st.session_state.elo_df = _load_optional(
        ELO_CSV_PATH, 'rb', pd.read_csv, "Error while reading elo csv file %s"
    )

def dashboard():
    """Render the Model Arena Scoreboard page.

    Data lives in ``st.session_state`` (``df``, ``elo_json``, ``elo_df``).
    Each value is loaded from storage only if it is not already cached, and
    the Refresh button (``on_refresh_click``) reloads them.

    The match log at SAVE_PATH is required, so a read error propagates. The
    Elo JSON and CSV are optional: a read failure is logged and ``None`` is
    stored, exactly as ``on_refresh_click`` does. The page then renders
    without the Elo section instead of crashing on first load.
    """
    st.title('Model Arena Scoreboard')

    if "df" not in st.session_state:
        with fs.open(SAVE_PATH, 'rb') as f:
            st.session_state.df = pd.read_csv(f)
    if "elo_json" not in st.session_state:
        try:
            with fs.open(ELO_JSON_PATH, 'r') as f:
                st.session_state.elo_json = json.load(f)
        except Exception as e:
            logger.error("Error while reading elo json file %s", e)
            st.session_state.elo_json = None
    if "elo_df" not in st.session_state:
        try:
            with fs.open(ELO_CSV_PATH, 'rb') as f:
                st.session_state.elo_df = pd.read_csv(f)
        except Exception as e:
            logger.error("Error while reading elo csv file %s", e)
            st.session_state.elo_df = None

    st.button("🔄 Refresh", on_click=on_refresh_click, key="refresh_btn")

    df = st.session_state.df
    if len(df) == 0:
        st.write("No Data to show")
        return

    metrics = calculate_metrics(df)

    model_descriptions = {
        "Ori Prime": "Foundational, large, and stable.",
        "Ori Swift": "Lighter and faster than Ori Prime.",
        "Ori Apex": "The top-performing model, fast and stable.",
        "Ori Apex XT": "Enhanced with more training, though slightly less stable than Ori Apex.",
        "Deepgram": "Deepgram Nova-2 API",
        "Azure": "Azure Speech Services API",
        "Sarvam": "Sarvam AI saarika:v2 API",
    }

    st.header('Model Descriptions')

    # Lay the description cards out alternately across two columns.
    cols = st.columns(2)
    for idx, (model, description) in enumerate(model_descriptions.items()):
        with cols[idx % 2]:
            st.markdown(f"""
                    <div style='padding: 1rem; border: 1px solid #e1e4e8; border-radius: 6px; margin-bottom: 1rem;'>
                        <h3 style='margin: 0; margin-bottom: 0.5rem;'>{model}</h3>
                        <p style='margin: 0; color: #6e7681;'>{description}</p>
                    </div>
                    """, unsafe_allow_html=True)

    st.header('Overall Performance')

    col1, col2, col3 = st.columns(3)

    with col1:
        create_metric_container("Total Matches", len(df))

    # Prefer the highest Elo rating; fall back to the best raw win rate when
    # the Elo mapping is missing or empty.
    elo_json = st.session_state.elo_json
    if elo_json:
        best_model = max(elo_json.items(), key=lambda item: item[1])[0]
    else:
        best_model = max(metrics.items(), key=lambda item: item[1]['win_rate'])[0]
    with col2:
        create_metric_container(
            "Best Model",
            get_model_abbreviation(best_model),
            full_name=best_model,
        )

    most_appearances = max(metrics.items(), key=lambda item: item[1]['appearances'])[0]
    with col3:
        create_metric_container(
            "Most Used",
            get_model_abbreviation(most_appearances),
            full_name=most_appearances,
        )

    # Summary table: the response-time columns are hidden and the index uses
    # display names.
    metrics_df = pd.DataFrame.from_dict(metrics, orient='index')
    metrics_df['win_rate'] = metrics_df['win_rate'].round(2)
    metrics_df = metrics_df.drop(columns=["avg_response_time", "response_time_std"])
    metrics_df.index = [get_model_abbreviation(model) for model in metrics_df.index]
    st.dataframe(metrics_df, use_container_width=True)

    st.header('Win Rates')
    st.plotly_chart(create_win_rate_chart(metrics), use_container_width=True)

    st.header('Appearance Distribution')
    st.plotly_chart(create_appearance_chart(metrics), use_container_width=True)

    if elo_json is not None and st.session_state.elo_df is not None:
        st.header('Elo Ratings')
        st.dataframe(pd.DataFrame(elo_json, index=[0]), use_container_width=True)
        st.plotly_chart(create_elo_chart(st.session_state.elo_df), use_container_width=True)

    st.header('Head-to-Head Analysis')
    st.plotly_chart(create_head_to_head_matrix(df), use_container_width=True)

if __name__ == "__main__":
    # Gate the dashboard behind an email form. The flag persists in
    # session_state across Streamlit reruns.
    if 'logged_in' not in st.session_state:
        st.session_state.logged_in = False

    if st.session_state.logged_in:
        dashboard()
    else:
        with st.form("contact_us_form"):
            st.subheader("Please enter your email to view the scoreboard")
            email = st.text_input("Email")
            submit_button = st.form_submit_button("Submit")

        if submit_button:
            if not email:
                st.error("Please fill in all fields")
            elif not validate_email(email):
                st.error("Please enter a valid email address")
            else:
                st.session_state.logged_in = True
                st.session_state.user_email = email
                write_email(st.session_state.user_email)
                st.success("Thanks for submitting your email")
                # Render immediately in this run. Later reruns take the
                # logged_in branch above.
                dashboard()