|
import argparse |
|
import os |
|
from data_loader import load_and_process_data, CATEGORICAL_COLUMNS |
|
from model_trainer import train_models |
|
from model_manager import save_models, load_models |
|
from model_predictor import predict |
|
from config import MODEL_DIR, CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS |
|
import wandb |
|
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report |
|
import pandas as pd |
|
|
|
|
|
|
|
|
|
|
|
def main(train=True, retrain=False):
    """Train or load the models, score the validation split, and predict the test set.

    Args:
        train: When true, fit models on the training split via ``train_models``
            and hand them to ``save_models``.
        retrain: Handled exactly like ``train``; either flag triggers fitting
            and saving. When both are false, models come from ``load_models``.

    Side effects:
        Creates ``MODEL_DIR`` if it does not exist, logs validation metrics and
        tables to a Weights & Biases run in the ``is_click_predictor`` project,
        and writes ``final_predictions.csv`` to the current working directory.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(MODEL_DIR, exist_ok=True)

    # NOTE(review): the leading "π" in the progress messages looks like a
    # mis-decoded emoji; it is reproduced as found because the original
    # character cannot be determined from here.
    print("\nπ Loading data...")
    X_train, X_val, y_train, y_val, test_df = load_and_process_data()

    if train or retrain:
        print("\nπ Training models...")
        models = train_models(X_train, y_train, CATEGORICAL_COLUMNS)
        save_models(models)
    else:
        print("\nπ Loading existing models...")
        models = load_models()

    param_grid = {
        "CATBOOST_PARAMS": CATBOOST_PARAMS,
        "XGB_PARAMS": XGB_PARAMS,
        "RF_PARAMS": RF_PARAMS,
    }

    # wandb reads WANDB_API_KEY from the environment by itself, so the former
    # bare os.getenv() call (whose result was discarded) is not needed.
    # Using the run as a context manager guarantees it is finished even when
    # prediction or metric computation raises.
    with wandb.init(project="is_click_predictor", config=param_grid) as run:
        print("\nπ Makings predictions for validation set...")
        predictions_val = predict(models, X_val)
        predicted_labels = predictions_val["is_click_predicted"]

        accuracy_val = accuracy_score(y_val, predicted_labels)
        balanced_accuracy_val = balanced_accuracy_score(y_val, predicted_labels)

        # output_dict=True yields a nested dict; transposing puts one
        # class/average per row so it can be shown as a table.
        classification_report_val = pd.DataFrame(
            classification_report(y_val, predicted_labels, output_dict=True)
        ).transpose()

        run.log(
            {
                "param_grid": param_grid,
                "accuracy_val": accuracy_val,
                "balanced_accuracy_val": balanced_accuracy_val,
                "classification_report_val_table": wandb.Table(
                    dataframe=classification_report_val
                ),
                "predictions_val_table": wandb.Table(dataframe=predictions_val),
                "y_val": y_val.tolist(),
            }
        )

    print("\nπ Making predictions for test set...")
    predictions = predict(models, test_df)

    predictions.to_csv("final_predictions.csv", index=False)
    print("\n✅ Predictions saved successfully as 'final_predictions.csv'!")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. With no flags this is equivalent to the former
    # hard-coded call main(train=True, retrain=False).
    parser = argparse.ArgumentParser(
        description="Train, retrain, or load models and write test-set predictions."
    )
    parser.add_argument(
        "--no-train",
        dest="train",
        action="store_false",
        help="skip training and load previously saved models "
        "(unless --retrain is also given)",
    )
    parser.add_argument(
        "--retrain",
        action="store_true",
        help="fit the models again and overwrite the saved ones",
    )
    cli_args = parser.parse_args()

    main(train=cli_args.train, retrain=cli_args.retrain)