import streamlit as st
import pandas as pd
import numpy as np
import yfinance as yf
from datetime import date, timedelta
from neuralprophet import NeuralProphet
from keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error
import plotly.graph_objs as go
import pickle

# ================= GENERAL CONFIGURATION ================= #
# Page-level settings are applied before any element is drawn.
st.set_page_config(page_title="Tesla Forecast AI", layout="wide")
st.title("CC RECURRENT NEURAL NETWORK")
st.subheader("📈 Prévision du cours de l'action de Tesla 🚗")
st.markdown("Une application d'IA qui utilise **NeuralProphet** et **LSTM** pour prédire la valeur de l’action Tesla (`Close`) dans les prochains jours.")

# ================= MODEL LOADING ================= #
@st.cache_resource
def load_lstm_model(path="best_model.keras"):
    """Load the saved Keras LSTM model from *path*.

    Wrapped in ``st.cache_resource`` so Streamlit keeps the loaded model
    around instead of reading it from disk on every rerun.
    """
    model = load_model(path)
    return model

@st.cache_resource
def load_neuralprophet_model(path="neuralprophet_model.pkl"):
    """Unpickle and return the NeuralProphet model stored at *path*.

    Cached with ``st.cache_resource`` like the LSTM loader.

    Security note: ``pickle.load`` can execute arbitrary code while
    deserializing, so *path* must only ever point at a trusted, locally
    produced file.
    """
    with open(path, mode="rb") as handle:
        model = pickle.load(handle)
    return model

# Both models are loaded on every run (served from Streamlit's resource cache
# after the first load), independently of the sidebar choice.
lstm_model, neural_model = load_lstm_model(), load_neuralprophet_model()

# ================= SIDEBAR ================= #
# User inputs: which model to run and how many steps ahead to forecast.
sidebar = st.sidebar
sidebar.header("Modèle de prevision")
model_choice = sidebar.selectbox(
    "Selectionnez un Modèle à utiliser",
    options=["LSTM", "NeuralProphet"],
)
n_days = sidebar.slider(
    "Nombre de jours à prédire",
    min_value=1,
    max_value=30,
    value=21,
)

# ================= DATA LOADING ================= #
@st.cache_data
def load_data():
    """Download TSLA daily closing prices and min-max scale them.

    Returns:
        (df, scaler) where ``df`` has columns ``ds`` (datetime), ``y``
        (closing price) and ``y_scaled`` (``y`` scaled by ``MinMaxScaler``
        with its default [0, 1] range), sorted by date with a fresh
        0..n-1 index; ``scaler`` is the fitted scaler used to map model
        outputs back to prices.

    Raises:
        ValueError: if the download returned no rows.
    """
    raw = yf.download("TSLA", start="2021-04-01", end=date.today())
    if raw.empty:
        raise ValueError("Aucune donnée TSLA n'a pu être téléchargée.")

    # Newer yfinance releases can return (field, ticker) MultiIndex columns
    # even for a single ticker; keep only the field level ("Close", ...).
    if isinstance(raw.columns, pd.MultiIndex):
        raw.columns = raw.columns.get_level_values(0)

    # .copy() gives an independent frame, so the column assignments below do
    # not operate on a slice of `raw` (SettingWithCopyWarning).
    df = raw.reset_index()[["Date", "Close"]].copy()
    df.columns = ["ds", "y"]
    df["ds"] = pd.to_datetime(df["ds"])
    df = df.sort_values("ds").reset_index(drop=True)

    # Normalisation: the scaler is fitted on the full downloaded history.
    scaler = MinMaxScaler()
    # fit_transform returns an (n, 1) array; flatten it for a 1-D column.
    df["y_scaled"] = scaler.fit_transform(df[["y"]]).ravel()
    return df, scaler

df, scaler = load_data()
# Frame handed to the models: the dates plus the *scaled* close, renamed "y".
df_model = df[["ds", "y_scaled"]].set_axis(["ds", "y"], axis=1)

# ================= PREDICTIONS ================= #
def predict_neuralprophet(model, df_input, periods, scaler):
    """Forecast ``periods`` future steps with a NeuralProphet model.

    Args:
        model: NeuralProphet model providing ``make_future_dataframe`` and
            ``predict``.
        df_input: history with columns ``ds`` and ``y`` (``y`` on the scaled
            axis).
        periods: number of future steps to forecast.
        scaler: fitted scaler whose ``inverse_transform`` maps predictions
            back to prices.

    Returns:
        DataFrame of the last ``periods`` forecast rows with columns ``ds``
        and ``yhat1`` (``yhat1`` in price units).
    """
    future = model.make_future_dataframe(df_input, periods=periods)
    forecast = model.predict(future)

    result = forecast.loc[:, ["ds", "yhat1"]].copy()
    # inverse_transform takes and returns an (n_samples, 1) array; flatten the
    # output so it is stored as an ordinary 1-D column.
    yhat_scaled = result["yhat1"].to_numpy().reshape(-1, 1)
    result["yhat1"] = scaler.inverse_transform(yhat_scaled).ravel()
    return result.tail(periods)

def predict_lstm(model, df_input, periods, scaler, seq_len=21, freq="B"):
    """Recursively forecast ``periods`` steps with the LSTM model.

    The last ``seq_len`` scaled values form the initial input window. Each
    prediction is appended to the window, the oldest value is dropped, and
    the window is fed back in for the next step.

    Args:
        model: Keras-style model whose ``predict`` accepts a
            ``(1, seq_len, 1)`` array.
        df_input: history with columns ``ds`` and ``y`` (``y`` on the scaled
            axis).
        periods: number of steps to forecast.
        scaler: fitted scaler whose ``inverse_transform`` maps predictions
            back to prices.
        seq_len: length of the input window.
        freq: pandas frequency used to date the forecast rows. Each step
            predicts the next row of the history, which holds trading days
            only. The default "B" (business days) therefore skips weekends;
            exchange holidays are not skipped. Pass "D" for calendar days.

    Returns:
        DataFrame with columns ``ds`` (forecast dates, starting one ``freq``
        period after the last observed date) and ``yhat1`` (price units).

    Raises:
        ValueError: if ``df_input`` has fewer than ``seq_len`` rows.
    """
    history = df_input["y"].to_numpy(dtype=float)
    if len(history) < seq_len:
        raise ValueError(
            f"predict_lstm needs at least {seq_len} observations, got {len(history)}."
        )

    window = history[-seq_len:]
    predictions = []
    for _ in range(periods):
        next_val = model.predict(window.reshape(1, seq_len, 1), verbose=0)[0][0]
        predictions.append(next_val)
        # Slide the window: drop the oldest value, append the new prediction.
        window = np.append(window[1:], next_val)

    if predictions:
        preds_scaled = np.asarray(predictions, dtype=float).reshape(-1, 1)
        preds = scaler.inverse_transform(preds_scaled).ravel()
    else:
        # inverse_transform rejects empty input, so handle periods == 0 here.
        preds = np.empty(0)

    step = pd.tseries.frequencies.to_offset(freq)
    first_date = df_input["ds"].iloc[-1] + step
    future_dates = pd.date_range(start=first_date, periods=periods, freq=freq)

    return pd.DataFrame({"ds": future_dates, "yhat1": preds})

# ================= FORECAST GENERATION ================= #
with st.spinner("🔮 Génération des prévisions..."):
    # Any choice other than "LSTM" falls through to NeuralProphet.
    use_lstm = model_choice == "LSTM"
    predictor = predict_lstm if use_lstm else predict_neuralprophet
    chosen_model = lstm_model if use_lstm else neural_model
    forecast_df = predictor(chosen_model, df_model, n_days, scaler)
tab1, tab2, tab3 = st.tabs(
    ["Visualisez les prevision", "visualisez les Metriques", "Téléchargez les previsions ici"]
)
# ================= VISUALIZATION ================= #
with tab1:
    # History and forecast are drawn as separate traces. The history frame `df`
    # is not modified, so the "Historique" trace shows observed prices only.
    last_obs = df[["ds", "y"]].iloc[[-1]].rename(columns={"y": "yhat1"})
    # Prefix the forecast with the last observed point so the red line starts
    # where the gray one ends.
    forecast_line = pd.concat([last_obs, forecast_df[["ds", "yhat1"]]], ignore_index=True)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df["ds"], y=df["y"], name="Historique", line=dict(color='gray')))
    fig.add_trace(go.Scatter(x=forecast_line["ds"], y=forecast_line["yhat1"],
                             name=f"Prévision sur {n_days} jours", line=dict(color='red')))
    fig.update_layout(title=f"Prévision sur {n_days} jours avec {model_choice}",
                      xaxis_title="Date", yaxis_title="Prix ($)",
                      template="plotly_white")
    st.plotly_chart(fig, use_container_width=True)

# ================= METRICS (HOLD-OUT BACKTEST) ================= #
with tab2:
    with st.expander("📈 Métriques de performance", expanded=True):
        # Hide the last `n_days` observations, forecast them with the selected
        # model, and score the forecast against the prices actually observed.
        # NOTE(review): if the saved models were trained on data covering
        # these days, the scores are in-sample (optimistic) — confirm.
        backtest_history = df_model.iloc[:-n_days]
        if model_choice == "LSTM":
            backtest_df = predict_lstm(lstm_model, backtest_history, n_days, scaler)
        else:
            backtest_df = predict_neuralprophet(neural_model, backtest_history, n_days, scaler)

        # df_model stores scaled closes; map the held-out tail back to prices.
        actual = scaler.inverse_transform(
            df_model["y"].to_numpy()[-n_days:].reshape(-1, 1)
        ).ravel()
        predicted = backtest_df["yhat1"].to_numpy()

        # NOTE(review): compared step by step, assuming forecast step k is the
        # k-th held-out trading day — confirm for the NeuralProphet model's freq.
        mape = mean_absolute_percentage_error(actual, predicted)
        # mean_squared_error returns the MSE; the square root gives the RMSE
        # the label announces, in price units.
        rmse = float(np.sqrt(mean_squared_error(actual, predicted)))

        st.caption(f"Backtest : prévision des {n_days} derniers jours observés, comparée aux cours réels.")
        col1, col2 = st.columns(2)
        with col1:
            st.metric("MAPE (Mean absolute percentage error)", f"{mape:.2%}")
        with col2:
            st.metric("RMSE (Root Mean Squared Error)", f"{rmse:.2f}")
with tab3:
    # ================= TABLE ================= #
    # Preview of the first five forecast rows, prices rounded to 2 decimals.
    st.subheader("📋 Détails des 5 premiers jours prédits")
    st.dataframe(forecast_df.head(5).style.format({"yhat1": "{:.2f}"}))

    # ================= EXPORT ================= #
    # The full forecast (all n_days rows) is offered as a UTF-8 CSV download.
    csv = forecast_df.to_csv(index=False).encode("utf-8")
    st.download_button("⬇️ Télécharger les prévisions", csv, file_name=f"prevision_{model_choice}.csv", mime="text/csv")

    # ================= FOOTER ================= #
    st.markdown("<hr style='margin-top:40px;'>", unsafe_allow_html=True)