import gradio as gr
import pandas as pd
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
import statsmodels.api as sm
import scipy.optimize as opt
import csv
import os

# Module-level state shared by the three UI panels. These are ordinary
# process-wide globals, so every browser session sees the same values.
data_df = None  # training table built by load_and_preview (None until a file loads)
poly = None  # PolynomialFeatures(degree=2) transformer fitted by train_model
model_power = None  # statsmodels MixedLM fit result for measured power
model_er = None  # statsmodels MixedLM fit result for measured ER

def hex_to_int(x):
    """Parse a hex setting such as ``"0x1A"``, ``"1a"`` or ``" FF "`` to ``int``.

    The value is converted with ``str()`` and surrounding whitespace is
    stripped before parsing in base 16, so non-string cells are read by their
    text form (e.g. the integer ``10`` parses as hex ``0x10`` == 16).

    Returns:
        The parsed integer, or ``np.nan`` when the text is not valid hex
        (empty string, ``None``, ``NaN``, ``"10.0"``, ...).
    """
    try:
        return int(str(x).strip(), 16)
    except (TypeError, ValueError):
        # Only parse failures map to NaN; a bare ``except`` would also
        # swallow KeyboardInterrupt/SystemExit and hide genuine bugs.
        return np.nan

# Panel 1: Load & Preview
# Column schema shared by both loaders and expected in re-imported CSVs.
_TRAINING_COLUMNS = [
    "Device", "power_hex", "er_hex",
    "power_dec", "er_dec", "power_meas", "er_meas",
]


def _resolve_upload_path(file):
    """Return the filesystem path of a Gradio upload value."""
    # Gradio 4+ passes a filepath string by default; older releases pass a
    # tempfile-like object exposing ``.name``. Accept both.
    if isinstance(file, (str, os.PathLike)):
        return file
    return file.name


def _load_excel(path):
    """Build the training table from every sheet whose name starts with "T3"."""
    frames = []
    # Context manager closes the workbook handle even if parsing fails.
    with pd.ExcelFile(path) as xls:
        for sheet in xls.sheet_names:
            if not sheet.startswith("T3"):
                continue
            df = pd.read_excel(xls, sheet_name=sheet, header=None)
            # Row 1 is a group header (forward-filled; presumably merged
            # cells) and row 2 a sub-header; join them into "Group Sub" names.
            h0 = df.iloc[1].ffill()
            h1 = df.iloc[2].fillna("")
            df.columns = [
                (f"{a} {b}".strip() if b else str(a).strip())
                for a, b in zip(h0, h1)
            ]

            # Data rows start at row 3.
            raw = df.iloc[3:][
                ["Setting Power", "Setting ER", "EA-4000 Power", "EA-4000 ER"]
            ].copy()
            # Power settings are carried down to following blank rows
            # (presumably written once per block of ER steps).
            raw["Setting Power"] = raw["Setting Power"].ffill()
            raw["power_hex"] = raw["Setting Power"]
            raw["er_hex"] = raw["Setting ER"]
            raw["power_dec"] = raw["power_hex"].apply(hex_to_int)
            raw["er_dec"] = raw["er_hex"].apply(hex_to_int)
            raw["power_meas"] = pd.to_numeric(raw["EA-4000 Power"], errors="coerce")
            raw["er_meas"] = pd.to_numeric(raw["EA-4000 ER"], errors="coerce")
            raw["Device"] = sheet

            # Keep only rows that carry a numeric power measurement.
            frames.append(raw.loc[raw["power_meas"].notna(), _TRAINING_COLUMNS])

    if not frames:
        raise ValueError("No valid sheets (prefix 'T3') found in Excel file.")
    return pd.concat(frames, ignore_index=True)


def _load_csv(path):
    """Load a CSV in the exported training-table format and re-derive types."""
    # Read the hex columns as text: with type inference pandas turns values
    # such as "1E5" into the float 100000.0, which hex_to_int cannot parse.
    df = pd.read_csv(
        path,
        quoting=csv.QUOTE_ALL,
        escapechar='\\',
        dtype={"power_hex": str, "er_hex": str},
    )
    required = set(_TRAINING_COLUMNS)
    if not required.issubset(df.columns):
        missing = required - set(df.columns)
        raise ValueError(f"CSV missing required columns: {missing}")
    # Decimal settings are always recomputed from the hex text, so stale or
    # hand-edited *_dec columns in the file are ignored.
    df['power_dec'] = df['power_hex'].apply(hex_to_int)
    df['er_dec'] = df['er_hex'].apply(hex_to_int)
    df['power_meas'] = pd.to_numeric(df['power_meas'], errors='coerce')
    df['er_meas'] = pd.to_numeric(df['er_meas'], errors='coerce')
    return df


# Panel 1: Load & Preview
def load_and_preview(file, n):
    """Load an uploaded .xlsx/.xls/.csv into the global ``data_df`` and preview it.

    Excel workbooks: every sheet whose name starts with "T3" is parsed
    (two header rows at index 1-2, data from index 3) and tagged with its
    sheet name in ``Device``. CSV files must contain the training-table
    columns, as written by the export button.

    Args:
        file: Gradio upload value (a filepath string or an object with
            ``.name``), or None when the upload is cleared.
        n: number of rows to show in the preview.

    Returns:
        ``(gr.update for the preview grid, status message)``. The grid is
        hidden when there is no file or loading fails; ``data_df`` is only
        replaced once a file has been parsed successfully.
    """
    global data_df
    if file is None:
        # Hide the preview until we have data
        return gr.update(visible=False), "▶️ Please upload an .xlsx or .csv file"
    try:
        path = _resolve_upload_path(file)
        ext = os.path.splitext(path)[1].lower()
        if ext in ('.xlsx', '.xls'):
            loaded = _load_excel(path)
        elif ext == '.csv':
            loaded = _load_csv(path)
        else:
            raise ValueError(f"Unsupported file type: {ext}")

        data_df = loaded
        preview_df = data_df.head(int(n))
        # Un-hide and populate the preview grid
        return gr.update(value=preview_df, visible=True), "✅ Data loaded successfully"
    except Exception as e:  # UI boundary: surface any parse error as a status
        return gr.update(visible=False), f"❌ {e}"

def export_csv():
    """Write the loaded training table to ``training_data.csv`` for download.

    The file lives at a fixed relative path, i.e. in the current working
    directory, and is overwritten on every call. Every field is wrapped in
    double quotes (``csv.QUOTE_ALL``) with backslash escaping; the intent is
    to keep spreadsheet tools from re-interpreting values such as hex
    settings. NOTE(review): Excel may still type-convert quoted numeric-looking
    fields; confirm this for your workflow.

    Returns:
        A ``gr.update`` that hides the download widget when no data has been
        loaded, or shows it pointing at the written file.
    """
    if data_df is None:
        return gr.update(visible=False, value=None)

    out_path = "training_data.csv"
    data_df.to_csv(
        out_path,
        index=False,
        quoting=csv.QUOTE_ALL,
        escapechar="\\",
    )
    return gr.update(visible=True, value=out_path)

# Panel 2: Train Hierarchical Quadratic RSM
def _fit_quadratic_mixedlm(frame, response, poly_tf):
    """Fit one quadratic MixedLM (grouped by Device) for a response column.

    Only rows with a device, both decimal settings and ``response`` present
    are used, so a missing ER reading does not discard a valid power point.

    Returns:
        ``(results, r2, rmse)`` where r2/rmse compare ``results.fittedvalues``
        with the observed response.

    Raises:
        ValueError: if no complete rows exist for ``response``.
    """
    sub = frame.dropna(subset=["Device", "power_dec", "er_dec", response])
    if sub.empty:
        raise ValueError(f"no complete rows for '{response}'")
    exog = poly_tf.transform(sub[["power_dec", "er_dec"]].to_numpy(dtype=float))
    y = sub[response].to_numpy(dtype=float)
    results = sm.MixedLM(endog=y, exog=exog, groups=sub["Device"].to_numpy()).fit()
    resid = y - np.asarray(results.fittedvalues)
    r2 = 1 - np.sum(resid ** 2) / np.sum((y - y.mean()) ** 2)
    rmse = np.sqrt(np.mean(resid ** 2))
    return results, r2, rmse


# Panel 2: Train Hierarchical Quadratic RSM
def train_model():
    """Fit the power and ER response-surface models on the loaded data.

    Builds degree-2 polynomial features (with bias column) of the decimal
    settings ``power_dec``/``er_dec`` and fits one statsmodels ``MixedLM`` per
    measured response, grouped by ``Device``. Rows with NaN settings are
    excluded, because the polynomial transform and MixedLM cannot take them.
    The globals ``poly``, ``model_power`` and ``model_er`` are replaced only
    when both fits succeed, so a failed run never leaves a half-trained state.

    Returns:
        A status string for the UI: R²/RMSE per model on success, or a
        message starting with "❌" on failure.
    """
    global poly, model_power, model_er
    if data_df is None:
        return "❌ No data loaded"

    settings = data_df[["power_dec", "er_dec"]].dropna()
    if settings.empty:
        return "❌ No rows with valid power/ER settings to train on"

    new_poly = PolynomialFeatures(degree=2, include_bias=True)
    new_poly.fit(settings.to_numpy(dtype=float))

    try:
        fit_p, r2p, rmse_p = _fit_quadratic_mixedlm(data_df, "power_meas", new_poly)
        fit_e, r2e, rmse_e = _fit_quadratic_mixedlm(data_df, "er_meas", new_poly)
    except Exception as e:  # UI boundary: statsmodels/numpy raise varied types
        return f"❌ Training failed: {e}"

    # Commit all three together so calibrate_and_predict sees a consistent set.
    poly, model_power, model_er = new_poly, fit_p, fit_e

    return (
        f"✅ Trained hierarchical quadratic RSM\n"
        f"Power → R²={r2p:.3f}, RMSE={rmse_p:.3f}\n"
        f"ER    → R²={r2e:.3f}, RMSE={rmse_e:.3f}"
    )

# Panel 3: Calibrate & Predict
def _calibration_offsets(calib_df):
    """Return mean (measured - model) residuals for power and ER.

    Only rows where both hex settings parse and both measurements are numeric
    contribute. Returns ``(0.0, 0.0)`` when no such rows exist (including an
    empty or missing table), i.e. the trained model is used unshifted.
    """
    if calib_df is None or len(calib_df) == 0:
        return 0.0, 0.0
    p_set = calib_df["power_hex"].apply(hex_to_int).to_numpy(dtype=float)
    e_set = calib_df["er_hex"].apply(hex_to_int).to_numpy(dtype=float)
    p_meas = pd.to_numeric(calib_df["power_meas"], errors="coerce").to_numpy(dtype=float)
    e_meas = pd.to_numeric(calib_df["er_meas"], errors="coerce").to_numpy(dtype=float)

    complete = ~(np.isnan(p_set) | np.isnan(e_set) | np.isnan(p_meas) | np.isnan(e_meas))
    if not complete.any():
        return 0.0, 0.0

    exog = poly.transform(np.column_stack([p_set[complete], e_set[complete]]))
    offset_p = float(np.mean(p_meas[complete] - model_power.predict(exog=exog)))
    offset_e = float(np.mean(e_meas[complete] - model_er.predict(exog=exog)))
    return offset_p, offset_e


# Panel 3: Calibrate & Predict
def calibrate_and_predict(calib_df, tp, te):
    """Find integer power/ER settings whose predicted output best hits a target.

    A constant offset per response (mean residual over the calibration rows)
    is added to the trained model's predictions. The squared error
    ``(power - tp)**2 + (er - te)**2`` is then minimised over the settings
    range seen in the loaded data. Both errors are weighted equally in their
    own measured units.

    Args:
        calib_df: DataFrame with columns power_hex, er_hex, power_meas,
            er_meas; incomplete or unparsable rows are ignored.
        tp: target power.
        te: target ER.

    Returns:
        ``{"Power Setting (hex)": ..., "ER Setting (hex)": ...}``, or
        ``{"error": ...}`` when the model is not trained or a target is missing.
    """
    if poly is None or model_power is None or model_er is None or data_df is None:
        return {"error": "Model not trained"}
    if tp is None or te is None:
        return {"error": "Target power and target ER are required"}

    offset_p, offset_e = _calibration_offsets(calib_df)

    # NOTE(review): bounds come from the currently loaded data, which may
    # differ from the training data if a new file was loaded after training.
    p_min, p_max = int(data_df["power_dec"].min()), int(data_df["power_dec"].max())
    e_min, e_max = int(data_df["er_dec"].min()), int(data_df["er_dec"].max())

    def obj(settings):
        x = np.asarray(settings, dtype=float).reshape(1, -1)
        xp = poly.transform(x)
        p0 = model_power.predict(exog=xp)[0] + offset_p
        e0 = model_er.predict(exog=xp)[0] + offset_e
        return (p0 - tp) ** 2 + (e0 - te) ** 2

    res = opt.minimize(
        obj,
        x0=[(p_min + p_max) / 2, (e_min + e_max) / 2],
        bounds=[(p_min, p_max), (e_min, e_max)],
    )

    # Settings are integer registers. Instead of rounding each coordinate
    # independently, score the integer points surrounding the continuous
    # optimum (clipped to the bounds) and keep the best one.
    p_opts = sorted({int(np.clip(v, p_min, p_max)) for v in (np.floor(res.x[0]), np.ceil(res.x[0]))})
    e_opts = sorted({int(np.clip(v, e_min, e_max)) for v in (np.floor(res.x[1]), np.ceil(res.x[1]))})
    ph, eh = min(((p, e) for p in p_opts for e in e_opts), key=obj)

    return {
        "Power Setting (hex)": hex(ph),
        "ER Setting (hex)"   : hex(eh)
    }


with gr.Blocks() as demo:
    gr.Markdown("# Power and ER Calibration APP")

    # Panel 1: upload a workbook/CSV, preview it, optionally export it.
    with gr.Tab("1. Load Data"):
        file_in  = gr.File(label="Upload .xlsx or .csv")
        n_slider = gr.Slider(1, 2000, value=99, step=1, label="Rows to preview")
        preview  = gr.DataFrame(visible=False)
        status   = gr.Textbox()

        # The slider value is read when a file is uploaded/changed.
        # NOTE(review): moving the slider afterwards does not refresh the preview.
        file_in.change(
            fn=load_and_preview,
            inputs=[file_in, n_slider],
            outputs=[preview, status]
        )

        export_btn = gr.Button("Export Training Dataset (CSV)")
        csv_file   = gr.File(label="Download CSV", visible=False)
        export_btn.click(
            fn=export_csv,
            inputs=None,
            outputs=csv_file
        )

    # Panel 2: fit the mixed-effects quadratic models on the loaded data.
    with gr.Tab("2. Train Model"):
        train_btn = gr.Button("Train RSM")
        train_out = gr.Textbox()
        train_btn.click(fn=train_model, inputs=None, outputs=train_out)

    # Panel 3: per-device calibration samples plus targets -> hex settings.
    with gr.Tab("3. Calibrate & Predict"):
        gr.Markdown("**Enter up to 5 calibration samples and target values**")
        calib_df = gr.DataFrame(
            headers=["power_hex", "er_hex", "power_meas", "er_meas"],
            row_count=5, col_count=4, interactive=True
        )
        tp       = gr.Number(value=2.5,  label="Target Power (dec)")
        te       = gr.Number(value=12.75, label="Target ER (dec)")
        pred_btn = gr.Button("Predict Settings")
        pred_out = gr.JSON(label="Predicted Settings")
        pred_btn.click(
            fn=calibrate_and_predict,
            inputs=[calib_df, tp, te],
            outputs=[pred_out]
        )

# Launch after the Blocks context has closed, and only when run as a script,
# so importing this module (e.g. from tests) does not start a server.
if __name__ == "__main__":
    demo.launch()