import os

import joblib
from catboost import CatBoostClassifier
from xgboost import XGBClassifier

from config import CATBOOST_MODEL_PATH, RF_MODEL_PATH, XGB_MODEL_PATH


def save_models(models):
    """Save trained models to the paths imported from ``config``.

    Args:
        models: Mapping with keys "CatBoost", "XGBoost" and "RandomForest".
            The "XGBoost" entry may be None, in which case no XGBoost model
            file is written.
    """
    models["CatBoost"].save_model(CATBOOST_MODEL_PATH)

    xgb = models["XGBoost"]
    if xgb is not None:
        # Save through the scikit-learn wrapper rather than the raw Booster
        # (get_booster().save_model) so the wrapper's metadata is stored with
        # the booster and restored by XGBClassifier.load_model() on reload.
        # The on-disk format is picked by xgboost from the XGB_MODEL_PATH
        # extension (and installed version), not by this code.
        # NOTE(review): when xgb is None, any previously saved file at
        # XGB_MODEL_PATH is left in place and may be stale.
        xgb.save_model(XGB_MODEL_PATH)

    joblib.dump(models["RandomForest"], RF_MODEL_PATH)
    print("✅ Models saved successfully!")
def load_models():
    """Load the trained models from the paths imported from ``config``.

    Returns:
        dict with keys "CatBoost", "XGBoost" and "RandomForest". The
        "XGBoost" value is None when no file exists at XGB_MODEL_PATH,
        mirroring save_models(), which skips writing it for a None model.
    """
    catboost = CatBoostClassifier()
    catboost.load_model(CATBOOST_MODEL_PATH)

    # save_models() only writes this file when an XGBoost model was present,
    # so a missing file means "no XGBoost model" rather than an error.
    xgb = None
    if os.path.exists(XGB_MODEL_PATH):
        xgb = XGBClassifier()
        xgb.load_model(XGB_MODEL_PATH)

    # NOTE: joblib.load unpickles the file; only load trusted model files.
    rf = joblib.load(RF_MODEL_PATH)
    return {"CatBoost": catboost, "XGBoost": xgb, "RandomForest": rf}