import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import xgboost as xgb
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR

os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"

import models.fm4m as fm4m


# Function to create model based on user input
def _create_model(
    model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None
):
    if model_name == "XGBClassifier":
        model = xgb.XGBClassifier(
            objective='binary:logistic',
            eval_metric='auc',
            max_depth=max_depth,
            n_estimators=n_estimators,
            alpha=alpha,
        )
    elif model_name == "SVR":
        model = SVR(degree=degree, kernel=kernel)
    elif model_name == "Kernel Ridge":
        model = KernelRidge(alpha=alpha, degree=degree, kernel=kernel)
    elif model_name == "Linear Regression":
        model = LinearRegression()
    elif model_name == "Default - Auto":
        return "Default Settings"
    else:
        return "Model not supported."

    return f"{model_name} * {model.get_params()}"


def create_downstream_model(state):
    """Build the downstream-model description selected in the UI state.

    Args:
        state: Mapping holding ``"model_name"`` plus any of the optional
            hyperparameters ``max_depth``, ``n_estimators``, ``alpha``,
            ``degree`` and ``kernel`` (missing ones are treated as None).

    Returns:
        Whatever ``_create_model`` returns: ``"<name> * <params>"``,
        ``"Default Settings"``, or ``"Model not supported."`` for a name
        ``_create_model`` does not recognise (previously this case
        returned None, which crashed ``display_eval``).

    Raises:
        KeyError: if ``state`` has no ``"model_name"`` entry.
    """
    model_name = state["model_name"]
    # Forward only the hyperparameters each model's constructor uses.
    hyperparameter_names = {
        "XGBClassifier": ("max_depth", "n_estimators", "alpha"),
        "SVR": ("degree", "kernel"),
        "Kernel Ridge": ("alpha", "degree", "kernel"),
    }.get(model_name, ())
    kwargs = {name: state.get(name) for name in hyperparameter_names}
    # Every name, including unknown ones, is delegated so callers always
    # receive a string.
    return _create_model(model_name, **kwargs)


def _parse_downstream(downstream):
    """Split a ``"<name> * <params repr>"`` string into ``(name, params)``.

    ``params`` is None when there is no ``*`` separator (e.g.
    ``"Default Settings"``) or when the params text cannot be evaluated.
    """
    name, separator, param_text = downstream.partition("*")
    name = name.strip()
    if not separator:
        return name, None
    try:
        # NOTE(review): eval() of text built from get_params() repr, whose
        # values come from UI state. String values are repr-quoted, but this
        # is still eval — consider a safer serialisation (e.g. JSON).
        # nan/inf are bound explicitly because their repr is a bare name.
        params = eval(
            param_text.strip(),
            globals(),
            {"nan": float("nan"), "inf": float("inf")},
        )
    except Exception:
        params = None
    return name, params


def display_eval(selected_models, dataset, task_type, state, plot_state):
    """Run the fm4m evaluation for the selected models and return its score.

    Args:
        selected_models: Names of the enabled fm4m models. With more than
            one, ``fm4m.multi_modal`` is used, otherwise
            ``fm4m.single_modal``.
        dataset: Dataset passed through to fm4m unchanged.
        task_type: ``"Classification"`` or ``"Regression"``.
        state: UI state consumed by ``create_downstream_model``.
        plot_state: Dict that receives the plotting outputs (ROC or parity
            data plus ``x_batch``/``y_batch``) returned by fm4m.

    Returns:
        The fm4m result, or a human-readable message when the selection is
        empty, the model is unsupported, the task type is unknown, fm4m
        raises, or fm4m returns a falsy result.
    """
    if not selected_models:
        return "Please select at least one enabled model."

    downstream = create_downstream_model(state)
    if not isinstance(downstream, str) or downstream == "Model not supported.":
        return "Model not supported."
    downstream_model, params = _parse_downstream(downstream)

    # Per task: the fm4m default model name and the plot_state keys that
    # receive fm4m's outputs after the leading result value.
    if task_type == "Classification":
        default_model = "DefaultClassifier"
        output_keys = ("roc_auc", "fpr", "tpr", "x_batch", "y_batch")
    elif task_type == "Regression":
        default_model = "DefaultRegressor"
        output_keys = ("RMSE", "y_batch_test", "y_prob", "x_batch", "y_batch")
    else:
        return "Data & Model Setting is incorrect"

    if downstream_model == "Default Settings":
        downstream_model, params = default_model, None

    try:
        if len(selected_models) > 1:
            outputs = fm4m.multi_modal(
                model_list=selected_models,
                downstream_model=downstream_model,
                params=params,
                dataset=dataset,
            )
        else:
            outputs = fm4m.single_modal(
                model=selected_models[0],
                downstream_model=downstream_model,
                params=params,
                dataset=dataset,
            )
        result, *values = outputs
        if len(values) != len(output_keys):
            raise ValueError(
                f"expected {len(output_keys) + 1} outputs from fm4m, "
                f"got {len(values) + 1}"
            )
        plot_state.update(zip(output_keys, values))
    except Exception as e:
        return f"An error occurred: {e}"
    return result or "Data & Model Setting is incorrect"


def _legend_if_labeled(ax, **legend_kwargs):
    """Draw a legend only when ``ax`` has labelled artists (avoids an empty box)."""
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(**legend_kwargs)


def _plot_latent_space(ax, state):
    """Scatter the 2-D points stored under ``x_batch`` and ``y_batch``."""
    x_batch, y_batch = state.get("x_batch"), state.get("y_batch")
    if x_batch is not None and y_batch is not None:
        class_0 = np.asarray(x_batch)
        class_1 = np.asarray(y_batch)
        # Draw on ax explicitly: plt.* targets pyplot's global current axes,
        # which another request may have changed.
        ax.scatter(class_1[:, 0], class_1[:, 1], c='red', label='Class 1')
        ax.scatter(class_0[:, 0], class_0[:, 1], c='blue', label='Class 0')
    ax.set_xlabel('Feature 1')
    ax.set_ylabel('Feature 2')
    ax.set_title('Dataset Distribution')
    _legend_if_labeled(ax)


def _plot_roc_curve(ax, state):
    """Plot the ROC curve from ``fpr``/``tpr`` with its ``roc_auc`` in the label."""
    roc_auc, fpr, tpr = state.get("roc_auc"), state.get("fpr"), state.get("tpr")
    if roc_auc is not None and fpr is not None and tpr is not None:
        try:
            ax.plot(
                fpr,
                tpr,
                color='darkorange',
                lw=2,
                label=f'ROC curve (area = {roc_auc:.4f})',
            )
            ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
            ax.set_xlim([0.0, 1.0])
            ax.set_ylim([0.0, 1.05])
        except (TypeError, ValueError):
            # Malformed results: leave the axes empty rather than fail the UI.
            pass
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic')
    _legend_if_labeled(ax, loc='lower right')


def _plot_parity(ax, state):
    """Scatter predicted (``y_prob``) vs actual (``y_batch_test``) with a y=x line."""
    rmse, y_batch_test, y_prob = (
        state.get("RMSE"),
        state.get("y_batch_test"),
        state.get("y_prob"),
    )
    ax.set_title("Parity plot")
    if y_batch_test is not None and y_prob is not None:
        try:
            actual = np.asarray(y_batch_test, dtype=float).ravel()
            predicted = np.asarray(y_prob, dtype=float).ravel()
            label = "Predicted vs Actual"
            if rmse is not None:
                label += f" (RMSE: {rmse:.4f})"
            ax.scatter(actual, predicted, color="blue", label=label)
            low = min(actual.min(), predicted.min())
            high = max(actual.max(), predicted.max())
            ax.plot([low, high], [low, high], 'r-')
        except (TypeError, ValueError):
            # Malformed or empty results: leave the axes without data.
            pass
    ax.set_xlabel('Actual Values')
    ax.set_ylabel('Predicted Values')
    _legend_if_labeled(ax, loc='lower right')


def display_plot(plot_type, state):
    """Render the requested plot from stored evaluation outputs.

    Args:
        plot_type: ``"Latent Space"``, ``"ROC-AUC"`` or ``"Parity Plot"``;
            any other value yields an empty figure.
        state: Dict of outputs written by ``display_eval`` (e.g.
            ``x_batch``, ``roc_auc``, ``RMSE``); missing entries leave the
            corresponding plot empty instead of raising.

    Returns:
        The new matplotlib ``Figure``. It is registered with pyplot and is
        not closed here — NOTE(review): the caller presumably owns closing it.
    """
    fig, ax = plt.subplots()
    plotters = {
        "Latent Space": _plot_latent_space,
        "ROC-AUC": _plot_roc_curve,
        "Parity Plot": _plot_parity,
    }
    plotter = plotters.get(plot_type)
    if plotter is not None:
        plotter(ax, state)
    return fig


def evaluate_and_log(selected_models, dataset, task_type, log_df, state):
    """Evaluate the selected models and append one row to the results log.

    Args:
        selected_models: Names of the enabled fm4m models.
        dataset: Dataset passed through to ``display_eval``.
        task_type: ``"Classification"`` or ``"Regression"``.
        log_df: Log DataFrame with at least an ``"id"`` column; rows whose
            ``id`` is an empty string are dropped as placeholders.
        state: UI state. A fresh plot-output dict for this run is stored in
            ``state["results"]`` under the new row id (the ``"results"``
            dict is created if absent).

    Returns:
        A new DataFrame with the row ``{"id", "Model", "Score"}`` appended
        and a fresh 0..n-1 index.
    """
    log_df = log_df[log_df['id'] != '']
    entry_id = len(log_df) + 1
    # Per-run container that display_eval fills with plotting outputs;
    # presumably read back by the plotting UI via this id — verify caller.
    plot_state = {"roc_auc": None, "RMSE": None, "x_batch": None}
    state.setdefault("results", {})[entry_id] = plot_state
    eval_output = display_eval(selected_models, dataset, task_type, state, plot_state)

    new_entry_df = pd.DataFrame(
        [
            {
                "id": entry_id,
                "Model": " + ".join(selected_models or []),
                # fm4m's result may not be a str; normalise before trimming.
                "Score": str(eval_output).replace(" Score", ""),
            }
        ]
    )
    # ignore_index avoids duplicate index labels after repeated appends.
    return pd.concat([log_df, new_entry_df], ignore_index=True)