import streamlit as st
import pandas as pd
import seaborn as sns
import numpy as np
import pickle
import matplotlib.pyplot as plt
from data_preparation import preprocess_data,data_imp
from clustering import perform_clustering, plot_clusters,summarize_cluster_characteristics
from feature_selection import select_features_pca, select_features_rfe, select_features_rf
from sklearn.preprocessing import StandardScaler

# Fetch the per-dataset feature descriptions and prediction-form defaults once,
# at import time; the result of data_imp() is unpacked into six module-level names.
(
    insurance_feature_descriptions,
    bankng_feature_descriptions,
    retail_feature_descriptions,
    insurance_defaults,
    banking_defaults,
    retail_defaults,
) = data_imp()

def load_data(dataset_choice):
    """Load the raw dataset that corresponds to the selected name.

    Args:
        dataset_choice: One of ``"Insurance"``, ``"Retail"`` or ``"Banking"``.

    Returns:
        pandas.DataFrame: The contents of the matching file, read with
        ``latin1`` encoding. File paths are relative to the working directory.

    Raises:
        ValueError: If ``dataset_choice`` is not one of the supported names.
    """
    if dataset_choice == "Insurance":
        return pd.read_sas('a2z_insurance.sas7bdat', encoding='latin1')
    if dataset_choice == "Retail":
        return pd.read_csv('retaildata.csv', encoding='latin1')
    if dataset_choice == "Banking":
        return pd.read_csv('bankingdata.csv', encoding='latin1')

    # An unknown name used to fall through every branch and fail with an
    # UnboundLocalError on `return data`; report the real problem instead.
    raise ValueError(
        f"Unsupported dataset choice: {dataset_choice!r}. "
        "Expected 'Insurance', 'Retail' or 'Banking'."
    )

# Function to display Business Understanding section
def display_business_understanding():
    """Render the "Business Understanding" page: objective text and an illustration."""
    # Markdown shown under the subheader; kept verbatim, including indentation.
    objective_markdown = """
    ###### Customer segmentation is a fundamental task in marketing and customer relationship management. With the advancements in data analytics and machine learning, it is now possible to group customers into distinct segments with a high degree of precision, allowing businesses to tailor their marketing strategies and offerings to each segment's unique needs and preferences.

    ###### Through this customer segmentation, businesses can achieve:
    - **Personalization**: Tailoring marketing strategies to meet the unique needs of each segment.
    - **Optimization**: Efficient allocation of marketing resources.
    - **Insight**: Gaining a deeper understanding of the customer base.
    - **Engagement**: Enhancing customer engagement and satisfaction.

    ###### => Problem/Requirement: Utilize machine learning and data analysis techniques in Python to perform customer segmentation.
    
    """

    st.subheader("Business Objective")
    st.write(objective_markdown)
    st.image(
        "Customer-Segmentation.png",
        caption="Customer Segmentation",
        use_column_width=True,
    )

# Function to display Dataset section
def display_dataset_selection():
    """Show the dataset picker, a short overview of the chosen data, and its feature notes.

    Returns:
        The DataFrame returned by ``load_data`` for the selected dataset.
    """
    dataset_choice = st.selectbox("Select Dataset", ("Insurance", "Retail", "Banking"))
    data = load_data(dataset_choice)
    row_count, column_count = data.shape

    st.write(f"Dataset: {dataset_choice}")
    st.write("Number of rows:", row_count)
    st.write("Number of columns:", column_count)
    st.write("First five rows of the data:")
    st.write(data.head())

    # Any choice other than Insurance or Retail falls back to the banking text.
    descriptions_by_dataset = {
        "Insurance": insurance_feature_descriptions,
        "Retail": retail_feature_descriptions,
    }
    st.write(descriptions_by_dataset.get(dataset_choice, bankng_feature_descriptions))
    return data
    
# Function to display Modeling & Evaluation section
def display_modeling_evaluation():
    """Render the modeling page: feature selection, clustering and cluster prediction.

    The sidebar picks a feature-selection method and the number of clusters.
    Pressing "Cluster" runs feature selection and clustering, shows the
    results, and stores them in ``st.session_state`` so they are still
    available on the rerun triggered by submitting the prediction form.
    The form then assigns a manually entered customer to a cluster and shows
    the average feature values of that cluster.
    """
    dataset_choice = st.selectbox("Select Dataset", ("Insurance", "Retail", "Banking"))
    data = load_data(dataset_choice)
    data = preprocess_data(data)

    # Sidebar for feature selection and clustering method
    st.sidebar.header("Feature Selection and Clustering Method")
    feature_selection_method = st.sidebar.selectbox(
        "Select feature selection method", ('PCA', 'RFE', 'Random Forest')
    )
    n_clusters = st.sidebar.slider("Number of clusters", min_value=2, max_value=10, value=3)

    if feature_selection_method == 'PCA':
        n_components = st.sidebar.slider("Number of PCA components", min_value=2, max_value=10, value=5)
    elif feature_selection_method in ['RFE', 'Random Forest']:
        n_features_to_select = st.sidebar.slider("Number of features to select", min_value=2, max_value=10, value=5)

    # Perform clustering on button click
    if st.sidebar.button("Cluster"):
        if feature_selection_method == 'PCA':
            selected_data, selected_features = select_features_pca(data, n_components)
        elif feature_selection_method == 'RFE':
            selected_data, selected_features = select_features_rfe(data, n_features_to_select)
        elif feature_selection_method == 'Random Forest':
            selected_data, selected_features = select_features_rf(data, n_features_to_select)

        st.write(f"Selected Features: {selected_features}")
        clustered_data, score, df_value_scaled, labels, model = perform_clustering(selected_data, n_clusters)
        st.write(f"Number of Clusters: {n_clusters}")
        st.write(f"Silhouette Score: {score}")
        st.write("Clustered Data")
        st.write(clustered_data)
        st.write("Cluster Visualization")
        plot_clusters(df_value_scaled, labels)

        # Fit the scaler on the training features so that new inputs can later be
        # standardised with the training mean/variance.
        # NOTE(review): assumes perform_clustering standardises `selected_data`
        # with a StandardScaler and that its columns are exactly
        # `selected_features`, in order -- confirm against clustering.py.
        scaler = StandardScaler().fit(np.asarray(selected_data))

        # Store everything the prediction form needs across Streamlit reruns
        st.session_state.selected_features = selected_features
        st.session_state.model = model
        st.session_state.clustered_data = clustered_data
        st.session_state.labels = labels
        st.session_state.df_value_scaled = df_value_scaled
        st.session_state.scaler = scaler
        st.session_state.dataset_choice = dataset_choice

    # Predict new data based on selected features
    required_state = ('selected_features', 'model', 'scaler', 'dataset_choice')
    if not all(key in st.session_state for key in required_state):
        return

    # A model trained on another dataset would pair that dataset's features
    # with this dataset's defaults, so ask for a fresh clustering run instead.
    if st.session_state.dataset_choice != dataset_choice:
        st.info(
            f"The stored model was trained on the {st.session_state.dataset_choice} dataset. "
            "Press 'Cluster' to train on the current selection before predicting."
        )
        return

    st.write("### Predict Cluster")

    # Default form values for the current dataset; unknown features default to 0.0
    defaults_by_dataset = {
        "Insurance": insurance_defaults,
        "Banking": banking_defaults,
        "Retail": retail_defaults,
    }
    dataset_defaults = defaults_by_dataset.get(dataset_choice, {})

    # Use st.form so the inputs are submitted together with a single rerun
    with st.form(key='prediction_form'):
        user_input = {}
        for feature in st.session_state.selected_features:
            default_value = dataset_defaults.get(feature, 0.0)
            user_input[feature] = st.number_input(f'Enter {feature}', value=default_value)

        submit_button = st.form_submit_button(label='Predict')

        if submit_button:
            user_df = pd.DataFrame(user_input, index=[0])

            # Only transform here. Fitting a StandardScaler on this single row
            # would map every possible input to the all-zero vector, so every
            # customer would be assigned to the same cluster.
            user_df_scaled = st.session_state.scaler.transform(user_df.to_numpy())

            cluster = st.session_state.model.predict(user_df_scaled)
            st.write(f'The predicted cluster for the input data is: {cluster[0]}')

            # Get the clustered data and labels from session state
            clustered_data = st.session_state.clustered_data
            labels = st.session_state.labels
            df_value_scaled = st.session_state.df_value_scaled

            # Summarize cluster characteristics
            summary = summarize_cluster_characteristics(clustered_data, labels, cluster[0])

            # Generate and display the inference
            inference_lines = [
                f"Based on the input features, the customer belongs to Cluster {cluster[0]}, "
                "which is characterized by the following average values:\n"
            ]
            inference_lines.extend(
                f"- {feature}: {value:.2f}\n" for feature, value in summary.items()
            )
            st.write("".join(inference_lines))

            plot_clusters(df_value_scaled, labels, new_data_point=user_df_scaled)


# Main app structure
def main():
    """Draw the page title and header, then render the page chosen in the sidebar menu."""
    st.title("Customer Segmentation Demo")
    st.header("Customer Segmentation")

    # Sidebar entries mapped to their render functions; dict order is the menu order.
    pages = {
        "Business Understanding": display_business_understanding,
        "Dataset": display_dataset_selection,
        "Modeling & Prediction": display_modeling_evaluation,
    }
    choice = st.sidebar.selectbox('Menu', list(pages))

    render_page = pages.get(choice)
    if render_page is not None:
        render_page()

# Start the app only when this file is executed as the main module, not when imported.
if __name__ == "__main__":
    main()