import pandas as pd
import numpy as np
import time
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from imblearn.over_sampling import SMOTE

# ===========================
#  CONFIGURATION
# ===========================

TRAIN_PATH = "data/train_dataset_full - train_dataset_full.csv"
# Smaller training file that can be swapped in for quick runs:
# TRAIN_PATH = "data/train_dataset_full - train_dataset_partial_for_testing.csv"
TEST_PATH = "data/X_test_1st.csv"  # Replace with actual test dataset path

# Columns that preprocess_data label-encodes.
CATEGORICAL_COLUMNS = ["gender", "product"]

# Identifier columns; split_and_balance_data includes them in the feature matrix.
IDS_COLUMNS = ["user_id", "session_id", "campaign_id", "webpage_id"]

# Label column to predict.
TARGET_COLUMN = "is_click"

# Raw model input columns taken from the CSV files.
FEATURE_COLUMNS = [
    "age_level",
    "gender",
    "product",
    "product_category_1",
    "product_category_2",
    "user_group_id",
    "user_depth",
    "city_development_index",
    "var_1",
]

# Per-group statistics created by add_aggregated_features.
AGGREGATED_COLUMNS = [
    "click_sum_age_sex_prod",
    "click_count_age_sex_prod",
    "unique_campaigns_age_sex_prod",
    "unique_webpages_age_sex_prod",
    "click_sum_city_age_prod",
    "click_count_city_age_prod",
    "unique_campaigns_city_age_prod",
    "unique_webpages_city_age_prod",
]

# Calendar parts extracted from the DateTime column in load_data.
TEMPORAL_COLUMNS = ["year", "month", "day", "hour", "minute", "weekday"]
# ===========================
#  LOAD DATASETS
# ===========================

def _add_temporal_features(frame):
    """Replace the ``DateTime`` column of *frame* with calendar-part columns.

    Timestamps are parsed with ``pd.to_datetime``. Missing ones are imputed
    with the most frequent timestamp of the same frame; the imputation is
    skipped when the column holds no valid timestamp at all. The columns
    ``year``, ``month``, ``day``, ``hour``, ``minute`` and ``weekday`` are
    appended, and ``DateTime`` is dropped. A frame without a ``DateTime``
    column is returned unchanged.
    """
    if "DateTime" not in frame.columns:
        return frame

    stamps = pd.to_datetime(frame["DateTime"])
    most_common = stamps.mode()  # mode() ignores NaT; empty if all are NaT
    if not most_common.empty:
        # Rebind a new Series instead of ``fillna(inplace=True)`` on a column
        # selection: that chained form is a no-op under pandas Copy-on-Write.
        stamps = stamps.fillna(most_common.iloc[0])

    return frame.assign(
        year=stamps.dt.year,
        month=stamps.dt.month,
        day=stamps.dt.day,
        hour=stamps.dt.hour,
        minute=stamps.dt.minute,
        weekday=stamps.dt.weekday,
    ).drop(columns="DateTime")


def load_data(train_path=TRAIN_PATH, test_path=TEST_PATH):
    """Load the train and test CSV files and apply basic cleaning.

    - Training rows whose ``TARGET_COLUMN`` value is missing are dropped.
    - ``DateTime`` (when present) is expanded into the calendar columns
      year/month/day/hour/minute/weekday and then removed.
    - Every remaining missing value in both frames is filled with -1.

    Args:
        train_path: Path of the training CSV; it must contain ``TARGET_COLUMN``.
        test_path: Path of the test CSV.

    Returns:
        ``(train_df, test_df)`` as cleaned DataFrames.
    """
    train_df = pd.read_csv(train_path)
    # .copy() makes the filtered frame independent, so later column
    # assignments do not act on a view (no SettingWithCopyWarning).
    train_df = train_df.loc[train_df[TARGET_COLUMN].notna()].copy()

    test_df = pd.read_csv(test_path)

    train_df = _add_temporal_features(train_df)
    test_df = _add_temporal_features(test_df)

    return train_df.fillna(-1), test_df.fillna(-1)


# ===========================
#  FEATURE ENGINEERING: AGGREGATIONS
# ===========================

def _merge_group_stats(df, test_df, keys, suffix):
    """Compute click statistics of ``df`` per ``keys`` group and left-merge them into both frames.

    Adds ``click_sum_<suffix>`` and ``click_count_<suffix>`` (sum and non-null
    count of ``is_click``), plus ``unique_campaigns_<suffix>`` and
    ``unique_webpages_<suffix>`` (distinct ``campaign_id`` / ``webpage_id``).
    Rows whose key combination has no group in ``df`` get NaN in those columns.
    """
    stats = (
        df.groupby(keys)
        .agg(**{
            f"click_sum_{suffix}": ("is_click", "sum"),
            f"click_count_{suffix}": ("is_click", "count"),
            f"unique_campaigns_{suffix}": ("campaign_id", "nunique"),
            f"unique_webpages_{suffix}": ("webpage_id", "nunique"),
        })
        .reset_index()
    )
    return (
        df.merge(stats, on=keys, how="left"),
        test_df.merge(stats, on=keys, how="left"),
    )


def add_aggregated_features(df, test_df):
    """Add per-group click statistics computed on ``df`` to ``df`` and ``test_df``.

    Two groupings of the training frame are used:

    - (age_level, gender, product) -> ``*_age_sex_prod`` columns
    - (city_development_index, age_level, product) -> ``*_city_age_prod`` columns

    After merging, every remaining NaN in both frames is replaced with 0. As a
    result, test rows from groups that never occur in ``df`` get zero
    sums/counts.

    NOTE(review): the click sums for a training row include that row's own
    ``is_click`` label, while test rows have no such self-contribution. This
    is target leakage; consider out-of-fold aggregation if it hurts
    generalisation.

    Returns:
        ``(df, test_df)`` as new frames with the aggregated columns appended.
    """
    df, test_df = _merge_group_stats(
        df, test_df, ["age_level", "gender", "product"], "age_sex_prod"
    )
    df, test_df = _merge_group_stats(
        df, test_df, ["city_development_index", "age_level", "product"], "city_age_prod"
    )
    return df.fillna(0), test_df.fillna(0)


# ===========================
#  ENCODE & NORMALIZE FEATURES
# ===========================

def preprocess_data(df, test_df, categorical_columns):
    """Label-encode categorical columns of the train and test frames.

    Each column's encoder is fitted on ``df`` after converting values to
    strings. ``test_df`` values are converted the same way, and labels that
    never appeared in ``df`` are encoded as -1. Both frames are modified in
    place and also returned. Numeric features are left unscaled.

    Args:
        df: Training frame; its columns are fitted and overwritten.
        test_df: Test frame; its columns are overwritten with the train encoding.
        categorical_columns: Names of the columns to encode.

    Returns:
        ``(df, test_df, label_encoders)``, where ``label_encoders`` maps each
        column name to its fitted ``LabelEncoder``.
    """
    label_encoders = {}
    for col in categorical_columns:
        le = LabelEncoder()
        df[col] = le.fit_transform(df[col].astype(str))
        # LabelEncoder codes are positions in its sorted ``classes_`` array.
        # A single dict lookup therefore reproduces ``le.transform`` without
        # calling it once per test row.
        codes = {label: code for code, label in enumerate(le.classes_)}
        test_df[col] = (
            test_df[col].astype(str).map(codes).fillna(-1).astype("int64")
        )
        label_encoders[col] = le  # Store encoders for later use

    return df, test_df, label_encoders


# ===========================
#  SPLIT DATA & HANDLE IMBALANCE
# ===========================

def split_and_balance_data(df, target_column, *, test_size=0.2, random_state=42):
    """Split into stratified train/validation sets, then oversample only the training part with SMOTE.

    Args:
        df: Frame containing the ID, feature, aggregated and temporal columns
            plus ``target_column``.
        target_column: Name of the label column.
        test_size: Fraction of rows held out for validation (default 0.2).
        random_state: Seed shared by the split and SMOTE (default 42).

    Returns:
        ``(X_train, X_val, y_train, y_val)``. The training pair is
        class-balanced by SMOTE; the validation pair keeps the original
        class distribution.
    """
    # NOTE(review): ID columns and label-encoded categoricals are interpolated
    # by SMOTE like continuous values. Consider excluding IDs or using SMOTENC.
    X = df[IDS_COLUMNS + FEATURE_COLUMNS + AGGREGATED_COLUMNS + TEMPORAL_COLUMNS]
    y = df[target_column]

    # Split BEFORE resampling. SMOTE synthesises points from neighbouring rows,
    # so resampling the full data would leak information from validation rows
    # into training and put synthetic rows into validation, inflating scores.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )

    smote = SMOTE(sampling_strategy="auto", random_state=random_state)
    X_train, y_train = smote.fit_resample(X_train, y_train)

    return X_train, X_val, y_train, y_val


# ===========================
#  VISUALIZE FEATURES
# ===========================

def visualize_features():
    """Show bar charts of the aggregated click-sum features.

    Loads the configured datasets, adds the aggregated features to the
    training frame, and displays two panels side by side:

    - click_sum_age_sex_prod by age level, split by gender
    - click_sum_city_age_prod by city development index, split by age level
    """
    train_frame, _unused_test = load_data()
    train_frame, _ = add_aggregated_features(train_frame, train_frame)

    sns.set_style("whitegrid")
    _fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(14, 6))

    # (axis, x column, y column, hue column, palette, title), drawn in order.
    panels = (
        (left_ax, "age_level", "click_sum_age_sex_prod", "gender",
         "coolwarm", "Total Clicks by Age & Gender vs Product"),
        (right_ax, "city_development_index", "click_sum_city_age_prod", "age_level",
         "viridis", "Total Clicks by City Development Index & Age"),
    )
    for ax, x_col, y_col, hue_col, palette_name, title in panels:
        sns.barplot(x=x_col, y=y_col, hue=hue_col,
                    data=train_frame, ax=ax, palette=palette_name)
        ax.set_title(title)

    plt.tight_layout()
    plt.show()


# ===========================
#  RUN FULL DATA PROCESSING PIPELINE
# ===========================

def load_and_process_data():
    """Run every pipeline stage in order and return the model-ready data.

    Stages: load_data -> add_aggregated_features -> preprocess_data (label
    encoding of CATEGORICAL_COLUMNS) -> split_and_balance_data.

    Returns:
        ``(X_train, X_val, y_train, y_val, test_df)``. The fitted label
        encoders are not returned.
    """
    train_frame, test_frame = add_aggregated_features(*load_data())
    train_frame, test_frame, _encoders = preprocess_data(
        train_frame, test_frame, CATEGORICAL_COLUMNS
    )
    X_train, X_val, y_train, y_val = split_and_balance_data(train_frame, TARGET_COLUMN)
    return X_train, X_val, y_train, y_val, test_frame


if __name__ == "__main__":
    # Script entry point: run the full pipeline once and report completion.
    print("🔹 Loading and processing data...")
    (
        X_train,
        X_val,
        y_train,
        y_val,
        test_df,
    ) = load_and_process_data()
    print("✅ Data successfully loaded and processed!")