bradmontierth committed
Commit 8cd7f61 · 0 Parent(s)

Initial commit
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,205 @@
+ ---
+ license: mit
+ tags:
+ - healthcare
+ - medicare
+ - xgboost
+ - logistic-regression
+ - length-of-stay
+ - readmission
+ - discharge-prediction
+ - classification
+ - regression
+ datasets:
+ - cms-lds
+
+ ---
+
+ # Medicare Inpatient Outcome Prediction Models
+
+ **Model ID:** Medicare LDS 2023 Inpatient Model Bundle
+ **Model Types:** XGBoost Regressor, Calibrated Binary Logistic Regression & Calibrated Multinomial Logistic Regression
+ **Dataset:** 2023 CMS Limited Data Set (LDS)
+ **Target Level:** Inpatient Encounter
+
+ ---
+
+ ## What the Model Predicts
+
+ This bundle contains three distinct models that predict key outcomes for a given inpatient hospital stay:
+
+ - **Length of Stay (Regression):** Predicts the total number of days for the inpatient stay.
+ - **Readmission Probability (Binary Classification):** Predicts the probability (from 0.0 to 1.0) that the patient will be readmitted to a hospital within 30 days of discharge.
+ - **Discharge Location Probability (Multiclass Classification):** Predicts the probability of each possible discharge location (e.g., Home, Skilled Nursing Facility, Hospice).
+
+ ---
+
+ ## Intended Use
+
+ This model bundle is designed to support a variety of clinical and operational workflows:
+
+ - **Discharge Planning & Care Management:** Identify high-risk patients who may need additional support or a specific type of post-acute care to prevent readmission.
+ - **Resource Planning:** Forecast bed-days and resource utilization based on predicted length of stay.
+ - **Actuarial Analysis:** Inform risk stratification and cost estimation models.
+ - **Benchmarking:** Compare observed outcomes against predicted risks for a given patient population.
+ - **Healthcare Research:** Analyze drivers of inpatient outcomes.
+
+ > **Note on Prediction Type:** The models are trained for **concurrent prediction** — they use clinical and demographic data available during the inpatient stay to predict outcomes related to that same stay.
+
+ ---
+
+ ## Model Performance
+
+ > These metrics reflect performance on a **20% test set** held out from the 2023 CMS LDS data. All values were calculated on unseen data and represent model generalization performance.
+
+ ### Model 1: Length of Stay (XGBoost Regressor)
+
+ | Target | R² | MAE (days) |
+ |--------------------|------|------------|
+ | `length_of_stay` | 0.25 | 2.72 |
+
+ ### Model 2: Readmission Probability (Calibrated Logistic Regression)
+
+ | Target | AUC ROC | Brier Score |
+ |--------------------------|---------|-------------|
+ | `readmission_probability`| 0.7483 | 0.1176 |
+
+ - **AUC ROC:** Measures the model's ability to distinguish between patients who will and will not be readmitted (higher is better).
+ - **Brier Score:** Measures the accuracy of the predicted probabilities (lower is better).
+
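+ For reference, both metrics can be reproduced with scikit-learn. A minimal sketch on toy data (the arrays below are illustrative, not model output):
+
+ ```python
+ import numpy as np
+ from sklearn.metrics import roc_auc_score, brier_score_loss
+
+ # Illustrative labels: 1 = readmitted within 30 days, 0 = not readmitted.
+ y_true = np.array([0, 0, 1, 0, 1, 1, 0, 0])
+ # Calibrated predicted probabilities, as produced by the readmission model.
+ y_pred_proba = np.array([0.05, 0.50, 0.65, 0.10, 0.45, 0.80, 0.30, 0.15])
+
+ print(f"AUC ROC: {roc_auc_score(y_true, y_pred_proba):.4f}")         # discrimination
+ print(f"Brier score: {brier_score_loss(y_true, y_pred_proba):.4f}")  # calibration
+ ```
+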
+ ### Model 3: Discharge Location (Calibrated Logistic Regression)
+
+ | Target | Accuracy | Brier Score (Macro Avg) |
+ |----------------------|----------|--------------------------|
+ | `discharge_location` | 0.5216 | 0.0771 |
+
+ - **Accuracy:** The overall percentage of encounters for which the model predicted the correct discharge location.
+ - **Brier Score (Macro Avg):** The average Brier Score across all possible discharge location classes (lower is better).
+
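+ The macro-averaged Brier score treats each discharge class as its own one-vs-rest problem and averages the per-class scores, mirroring the per-class evaluation in the training notebook. A minimal sketch (toy probabilities, not model output):
+
+ ```python
+ import numpy as np
+ from sklearn.metrics import brier_score_loss
+
+ classes = ["home", "snf", "hospice"]   # illustrative subset of discharge classes
+ y_true = np.array([0, 0, 1, 2, 0, 1])  # encoded true discharge locations
+ y_pred_proba = np.array([               # one row per encounter; rows sum to 1.0
+     [0.7, 0.2, 0.1],
+     [0.6, 0.3, 0.1],
+     [0.2, 0.7, 0.1],
+     [0.3, 0.2, 0.5],
+     [0.8, 0.1, 0.1],
+     [0.4, 0.5, 0.1],
+ ])
+
+ # One-vs-rest Brier score per class, then the macro average.
+ per_class = [
+     brier_score_loss((y_true == i).astype(int), y_pred_proba[:, i])
+     for i in range(len(classes))
+ ]
+ print(dict(zip(classes, np.round(per_class, 4))))
+ print(f"Macro-avg Brier: {np.mean(per_class):.4f}")
+ ```
+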
+ ---
+
+ ## Files Included
+
+ - `inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz` — A gzip-compressed pickle file containing the trained models, feature lists, and encoders. The filename may vary based on training parameters (e.g., `_fs` indicates feature selection was used). The bundle includes (see the loading sketch after this list):
+   - `los_model` (XGBoost)
+   - `readmission_model` (Calibrated Logistic Regression)
+   - `discharge_model` (Calibrated Logistic Regression)
+   - `feature_columns_*` (specific feature lists for each model)
+   - `le_discharge` (label encoder for discharge location)
+ - `Train Tuva Concurrent Inpatient Models.ipynb` — The notebook used for training, feature selection, and evaluation on Snowflake.
+ - `predict inpatient.ipynb` — An example prediction notebook for running the bundle on new data.
+ - `feature_fill_rate_inpatient.csv` — A diagnostic file detailing the prevalence of each feature in the training dataset.
+ - `inpatient_feature_importance.csv` — A file containing the calculated importance of each feature for each of the three models.
+
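+ A minimal sketch of loading the bundle and scoring new encounters (`df` is a placeholder for your one-hot-encoded feature frame; the `reindex` pattern mirrors the example prediction notebook):
+
+ ```python
+ import gzip
+ import pickle
+
+ import pandas as pd
+
+ # Load the gzip-compressed bundle (assumed to be in the working directory).
+ with gzip.open("inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz", "rb") as f:
+     bundle = pickle.load(f)
+
+ df = pd.DataFrame()  # placeholder: your one-hot-encoded inpatient encounters
+
+ # Align input columns to each model's feature list; missing features are zero-filled.
+ los_pred = bundle["los_model"].predict(
+     df.reindex(columns=bundle["feature_columns_los"], fill_value=0))
+ readmit_proba = bundle["readmission_model"].predict_proba(
+     df.reindex(columns=bundle["feature_columns_readmission"], fill_value=0))[:, 1]
+ discharge_proba = bundle["discharge_model"].predict_proba(
+     df.reindex(columns=bundle["feature_columns_discharge"], fill_value=0))
+ discharge_classes = bundle["le_discharge"].classes_  # column order of discharge_proba
+ ```
+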
+ ---
+
+ ## Understanding Model Artifacts
+
+ This repository includes two key CSV files that provide insight into the model's training data and internal logic. These are generated by the training notebook, which also populates corresponding tables in Snowflake for easier querying (`FEATURE_FREQUENCY_STATS_INPATIENT` and `MODEL_FEATURE_IMPORTANCE_INPATIENT`).
+
+ ### Feature Fill Rates (`feature_fill_rate_inpatient.csv`)
+
+ This file is a diagnostic tool for understanding the input data used to train the models. It helps you check for data drift or data quality issues.
+
+ | Column | Description |
+ |---|---|
+ | `FEATURE_NAME` | The name of the input feature (e.g., `age_at_admit`, `cond_hypertension`). |
+ | `POSITIVE_COUNT` | The number of records in the training set where this feature was present (non-zero). |
+ | `TOTAL_ROWS` | The total number of records in the training set. |
+ | `POSITIVE_RATE_PERCENT` | The prevalence or "fill rate" of the feature, as a percentage (`POSITIVE_COUNT` / `TOTAL_ROWS` × 100). |
+
+ **How to Use:** Compare the `POSITIVE_RATE_PERCENT` from this file with the rates from your own prediction input data, as sketched below. Significant discrepancies can point to data pipeline issues and may explain poor model performance.
+
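+ A minimal sketch of that comparison (`df_input` is a placeholder for your one-hot-encoded prediction inputs; the 5-point threshold is illustrative):
+
+ ```python
+ import pandas as pd
+
+ train = pd.read_csv("feature_fill_rate_inpatient.csv")
+ train["TRAINING_FILL_RATE"] = train["POSITIVE_RATE_PERCENT"] / 100.0
+
+ df_input = pd.DataFrame()  # placeholder: your one-hot-encoded prediction inputs
+ # NOTE: align feature-name casing with the CSV first; the example notebook upper-cases both.
+ input_rates = df_input.mean(numeric_only=True).rename("INPUT_FILL_RATE")
+
+ comparison = train.set_index("FEATURE_NAME").join(input_rates, how="outer")
+ comparison["DIFFERENCE"] = comparison["INPUT_FILL_RATE"] - comparison["TRAINING_FILL_RATE"]
+
+ # Flag features whose prevalence shifted more than 5 percentage points.
+ print(comparison[comparison["DIFFERENCE"].abs() > 0.05].sort_values("DIFFERENCE"))
+ ```
+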
+ ### Feature Importances (`inpatient_feature_importance.csv`)
+
+ This file provides model explainability by showing which features are most influential for each of the three models.
+
+ | Column | Description |
+ |---|---|
+ | `MODEL_NAME` | Identifies the model (e.g., `Inpatient_LOS_FeatureSelected`). |
+ | `FEATURE_NAME` | The name of the input feature. |
+ | `IMPORTANCE_VALUE` | A numeric score indicating the feature's influence. Higher is more important. |
+ | `IMPORTANCE_RANK` | The rank of the feature's importance for that specific model (1 is most important). |
+
+ **How to Use:** Use this file to understand the key drivers behind the model's predictions. For example, you can filter by `MODEL_NAME` for the readmission model and sort by `IMPORTANCE_RANK` to see what most influences readmission risk. This is useful for clinical validation and debugging.
+
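+ For example, a quick look at the top drivers of readmission risk (the `MODEL_NAME` filter is illustrative; check the CSV for the exact names logged in your run):
+
+ ```python
+ import pandas as pd
+
+ imp = pd.read_csv("inpatient_feature_importance.csv")
+
+ top10 = (
+     imp[imp["MODEL_NAME"].str.contains("Readmission", case=False)]
+     .sort_values("IMPORTANCE_RANK")
+     .head(10)
+ )
+ print(top10[["FEATURE_NAME", "IMPORTANCE_VALUE", "IMPORTANCE_RANK"]])
+ ```
+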
+ ---
+
+ ## Quick Start: End-to-End Workflow
+
+ This section provides high-level instructions for running a model with the Tuva Project. The workflow involves preparing benchmark data using dbt, running a Python prediction script, and optionally ingesting the results back into dbt for analysis.
+
+ ### 1. Configure Your dbt Project
+
+ You need to enable the correct variables in your `dbt_project.yml` file to control the workflow.
+
+ #### A. Enable Benchmark Marts
+
+ These two variables control which parts of the Tuva Project are active. Both are `false` by default.
+
+ ```yaml
+ # in dbt_project.yml
+ vars:
+   benchmarks_train: true
+   benchmarks_already_created: true
+ ```
+
+ - `benchmarks_train`: Set to `true` to build the datasets that the ML models will use for making predictions.
+ - `benchmarks_already_created`: Set to `true` to ingest model predictions back into the project as a new dbt source.
+
+ #### B. (Optional) Set Prediction Source Locations
+
+ If you plan to bring predictions back into dbt for analysis, you must define where dbt can find the prediction data.
+
+ ```yaml
+ # in dbt_project.yml
+ vars:
+   predictions_person_year: "{{ source('benchmark_output', 'person_year') }}"
+   predictions_inpatient: "{{ source('benchmark_output', 'inpatient') }}"
+ ```
+
+ #### C. Configure `sources.yml`
+
+ Ensure your `sources.yml` file includes a definition for the source you referenced above (e.g., `benchmark_output`) that points to the database and schema where your model's prediction outputs are stored, as in the sketch below.
+
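+ A minimal `sources.yml` sketch for the `benchmark_output` source referenced above (the database and schema names are placeholders to adapt to your warehouse):
+
+ ```yaml
+ # in sources.yml
+ version: 2
+
+ sources:
+   - name: benchmark_output
+     database: your_database  # placeholder: where the prediction outputs land
+     schema: your_schema      # placeholder
+     tables:
+       - name: person_year
+       - name: inpatient
+ ```
+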
+ ---
+
+ ### 2. The 3-Step Run Process
+
+ This workflow can be managed by any orchestration tool (e.g., Airflow, Prefect, Fabric Notebooks) or run manually from the command line.
+
+ #### Step 1: Generate the Training & Benchmarking Data
+
+ Run the Tuva Project with `benchmarks_train` enabled. This creates the input data required by the ML model.
+
+ ```bash
+ dbt build --vars '{benchmarks_train: true}'
+ ```
+
+ To run only the benchmark mart:
+
+ ```bash
+ dbt build --select tag:benchmarks_train --vars '{benchmarks_train: true}'
+ ```
+
+ #### Step 2: Run the Prediction Python Code
+
+ Execute the prediction notebook (`predict inpatient.ipynb`) to generate predictions. It reads the data created in Step 1 and writes the prediction outputs to a persistent location (e.g., a table in your data warehouse).
+
+ *Each model's repository includes the example Snowflake Notebook code that was used in Tuva's environment.*
+
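+ If you need to run the notebook headlessly outside Snowflake, one option is papermill (an assumption about your tooling, not a project requirement):
+
+ ```bash
+ pip install papermill
+ papermill "predict inpatient.ipynb" "predict inpatient.output.ipynb"
+ ```
+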
+ #### Step 3: (Optional) Bring Predictions Back into the Tuva Project
+
+ To bring the predictions back into the Tuva Project for analysis, run dbt again with `benchmarks_already_created` enabled. This populates the analytics marts.
+
+ ```bash
+ dbt build --vars '{benchmarks_already_created: true, benchmarks_train: false}'
+ ```
+
+ To run only the analysis models:
+
+ ```bash
+ dbt build --select tag:benchmarks_analysis --vars '{benchmarks_already_created: true, benchmarks_train: false}'
+ ```
+
+ ---
Train Tuva Concurrent Inpatient Models.ipynb ADDED
@@ -0,0 +1,31 @@
+ {
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Streamlit Notebook",
+ "name": "streamlit"
+ },
+ "lastEditStatus": {
+ "notebookId": "6rovstl42ft2p5id6gwo",
+ "authorId": "374530764978",
+ "authorName": "BRAD",
+ "authorEmail": "[email protected]",
+ "sessionId": "65561efa-4d18-4072-8f4d-10240cb902ba",
+ "lastEditTime": 1750870004305
+ }
+ },
+ "nbformat_minor": 5,
+ "nbformat": 4,
+ "cells": [
+ {
+ "cell_type": "code",
+ "id": "3775908f-ca36-4846-8f38-5adca39217f2",
+ "metadata": {
+ "language": "python",
+ "name": "cell1"
+ },
+ "source": "0]}\")\n \n base_readmit_model = LogisticRegression(class_weight='balanced', random_state=42, max_iter=1000, solver='liblinear')\n \n # Determine the feature set to use.\n if FAST_MODE:\n print(\"\\n[FAST MODE] Skipping feature selection. Using all available features.\")\n best_readmission_features = X_train_base.columns.tolist()\n else:\n best_readmission_features = find_best_feature_subset(\n model=base_readmit_model, X_train=X_train_base, y_train=y_train_base, X_val=X_calib_read, y_val=y_calib_read,\n scoring_func=roc_auc_score, higher_is_better=True, model_name=\"Readmission (Logistic Regression)\"\n )\n\n print(f\"\\nTraining final Readmission model pipeline using {len(best_readmission_features)} features...\")\n base_model_for_calib = clone(base_readmit_model)\n base_model_for_calib.fit(X_train_base[best_readmission_features], y_train_base)\n \n # Log feature importances from the base (uncalibrated) model.\n uncal_read_model_name = f\"Inpatient_Readmission_Base_Uncalibrated_{MODEL_NAME_SUFFIX}{EXCLUSION_SUFFIX}\"\n log_feature_importances_to_snowflake(session, base_model_for_calib, best_readmission_features, MODEL_RUN_ID, uncal_read_model_name, TARGET_READMISSION, FEATURE_IMPORTANCE_TABLE_NAME)\n \n # Evaluate and log metrics for the uncalibrated model for comparison.\n y_pred_proba_uncal = base_model_for_calib.predict_proba(X_test_read[best_readmission_features])[:, 1]\n uncalibrated_metrics = calculate_binary_classification_proba_metrics(y_test_read, y_pred_proba_uncal)\n log_model_metrics_to_snowflake(session, MODEL_RUN_ID, uncal_read_model_name, TARGET_READMISSION + \"_Probability\", uncalibrated_metrics, \"Binary_Uncalibrated\", METRICS_TABLE_NAME, MODEL_SOURCE_TAG, MODEL_YEAR_TAG)\n \n # Calibrate the model on the held-out calibration set.\n calibrated_readmission_model = CalibratedClassifierCV(base_model_for_calib, method='isotonic', cv='prefit')\n calibrated_readmission_model.fit(X_calib_read[best_readmission_features], y_calib_read)\n y_pred_proba_cal = calibrated_readmission_model.predict_proba(X_test_read[best_readmission_features])[:, 1]\n\n print(\"\\nCalibrated Readmission Model - Test Set Evaluation:\")\n calibrated_proba_metrics = calculate_binary_classification_proba_metrics(y_test_read, y_pred_proba_cal)\n for k, v in calibrated_proba_metrics.items(): print(f\" {k}: {v:.4f}\")\n \n cal_read_model_name = f\"Inpatient_Readmission_Calibrated_{MODEL_NAME_SUFFIX}{EXCLUSION_SUFFIX}\"\n log_model_metrics_to_snowflake(session, MODEL_RUN_ID, cal_read_model_name, TARGET_READMISSION + \"_Probability\", calibrated_proba_metrics, \"Binary_Calibrated\", METRICS_TABLE_NAME, MODEL_SOURCE_TAG, MODEL_YEAR_TAG)\n\n# --- 4.3 Model 3: Predicting Discharge Location (Multiclass Classification) ---\nprint(\"\\n\" + \"=\"*80)\nprint(\"--- Training Model 3: Calibrated Discharge Location ---\")\nprint(\"=\"*80)\nTARGET_DISCHARGE = 'discharge_location'\ncalibrated_discharge_model, le_discharge, best_discharge_features = None, None, None\n\nif TARGET_DISCHARGE not in df_pd.columns:\n print(f\"Error: Target column '{TARGET_DISCHARGE}' not found. Skipping Discharge Location model.\")\nelse:\n le_discharge = LabelEncoder()\n y_discharge_encoded = le_discharge.fit_transform(df_pd[TARGET_DISCHARGE])\n num_classes_discharge = len(le_discharge.classes_)\n print(f\"Discharge Location: {num_classes_discharge} classes found: {le_discharge.classes_}\")\n \n # Split data: 60% base train, 20% calibration, 20% test\n stratify_discharge = y_discharge_encoded if num_classes_discharge > 1 else None\n X_train_full_disc, X_test_disc, y_train_full_disc_enc, y_test_disc_enc = train_test_split(X, y_discharge_encoded, test_size=0.2, random_state=42, stratify=stratify_discharge)\n X_train_base_disc, X_calib_disc, y_train_base_disc_enc, y_calib_disc_enc = train_test_split(X_train_full_disc, y_train_full_disc_enc, test_size=0.25, random_state=42, stratify=y_train_full_disc_enc if num_classes_discharge > 1 else None)\n print(f\"Data split for discharge: Base train: {X_train_base_disc.shape[0]}, Calibration: {X_calib_disc.shape[0]}, Test: {X_test_disc.shape[0]}\")\n \n base_discharge_model = LogisticRegression(random_state=42, max_iter=1000, solver='lbfgs', multi_class='multinomial', class_weight='balanced')\n \n # Determine the feature set to use.\n if FAST_MODE:\n print(\"\\n[FAST MODE] Skipping feature selection. Using all available features.\")\n best_discharge_features = X_train_base_disc.columns.tolist()\n else:\n best_discharge_features = find_best_feature_subset(\n model=base_discharge_model, X_train=X_train_base_disc, y_train=y_train_base_disc_enc, X_val=X_calib_disc, y_val=y_calib_disc_enc,\n scoring_func=log_loss, higher_is_better=False, model_name=\"Discharge Location (Multinomial Regression)\"\n )\n\n print(f\"\\nTraining final Discharge Location model pipeline using {len(best_discharge_features)} features...\")\n base_model_for_calib_disc = clone(base_discharge_model)\n base_model_for_calib_disc.fit(X_train_base_disc[best_discharge_features], y_train_base_disc_enc)\n \n discharge_model_name = f\"Inpatient_Discharge_Cal_Overall_{MODEL_NAME_SUFFIX}{EXCLUSION_SUFFIX}\"\n log_feature_importances_to_snowflake(session, base_model_for_calib_disc, best_discharge_features, MODEL_RUN_ID, discharge_model_name, TARGET_DISCHARGE, FEATURE_IMPORTANCE_TABLE_NAME)\n \n # Calibrate the model. 'sigmoid' is used for one-vs-rest calibration, suitable for multiclass.\n calibrated_discharge_model = CalibratedClassifierCV(base_model_for_calib_disc, method='sigmoid', cv='prefit')\n calibrated_discharge_model.fit(X_calib_disc[best_discharge_features], y_calib_disc_enc)\n y_pred_proba_discharge_calibrated = calibrated_discharge_model.predict_proba(X_test_disc[best_discharge_features])\n y_pred_labels_discharge_calibrated = calibrated_discharge_model.predict(X_test_disc[best_discharge_features])\n \n print(\"\\nCalibrated Discharge Model - Test Set Evaluation:\")\n calibrated_disc_metrics = calculate_multiclass_classification_metrics(y_test_disc_enc, y_pred_labels_discharge_calibrated, y_pred_proba_discharge_calibrated, le_discharge.classes_)\n \n # Log the overall multiclass metrics.\n overall_cal_metrics_to_log = {k: v for k, v in calibrated_disc_metrics.items() if k != 'per_class_details'}\n overall_cal_metrics_to_log['BRIER_SCORE'] = calibrated_disc_metrics.get('BRIER_SCORE_MACRO_AVG')\n log_model_metrics_to_snowflake(session, MODEL_RUN_ID, discharge_model_name, TARGET_DISCHARGE, overall_cal_metrics_to_log, \"Multiclass_Cal_Overall\", METRICS_TABLE_NAME, MODEL_SOURCE_TAG, MODEL_YEAR_TAG)\n \n # --- FIX: Log the per-class metrics by mapping keys correctly ---\n discharge_class_model_name = f\"Inpatient_Discharge_Cal_Class_{MODEL_NAME_SUFFIX}{EXCLUSION_SUFFIX}\"\n for class_detail in calibrated_disc_metrics.get('per_class_details', []):\n # Create a new dict with keys the logging function expects.\n per_class_metrics_to_log = {\n 'BRIER_SCORE': class_detail.get('brier_score'),\n 'AVG_Y_PRED': class_detail.get('avg_pred_proba'),\n 'AVG_Y_TRUE': class_detail.get('true_proportion'),\n 'PRED_RATIO': class_detail.get('proba_ratio'),\n }\n log_model_metrics_to_snowflake(\n session, MODEL_RUN_ID, discharge_class_model_name,\n f\"{TARGET_DISCHARGE}_Class_{class_detail['class_name']}\",\n per_class_metrics_to_log, # Use the correctly mapped dictionary\n \"Multiclass_Cal_ClassDetail\", METRICS_TABLE_NAME, MODEL_SOURCE_TAG, MODEL_YEAR_TAG\n )\n\n print(\"\\nCalibrated Classification Report:\\n\", classification_report(y_test_disc_enc, y_pred_labels_discharge_calibrated, target_names=le_discharge.classes_.astype(str), zero_division=0, digits=4))\n\n\n# =============================================================================\n# 5. MODEL SAVING\n# =============================================================================\nprint(\"\\n\" + \"=\"*80)\nprint(\"--- Saving Models and Artifacts ---\")\nprint(\"=\"*80)\n\n# Bundle all necessary objects for deployment into a single dictionary.\ninpatient_models_bundle = {\n 'los_model': los_model,\n 'readmission_model': calibrated_readmission_model,\n 'discharge_model': calibrated_discharge_model,\n 'feature_columns_los': best_los_features,\n 'feature_columns_readmission': best_readmission_features,\n 'feature_columns_discharge': best_discharge_features,\n 'le_discharge': le_discharge,\n 'model_run_id': MODEL_RUN_ID,\n 'fast_mode': FAST_MODE,\n 'excluded_feature_prefixes': EXCLUDE_FEATURE_PREFIXES\n}\n\n# Create a descriptive file name for the bundle.\nBUNDLE_SUFFIX = \"fast\" if FAST_MODE else \"fs\"\nEXCLUSION_FILE_TAG = f\"_excl_{'-'.join([p.strip('_').lower() for p in EXCLUDE_FEATURE_PREFIXES])}\" if EXCLUDE_FEATURE_PREFIXES else \"\"\nBUNDLE_FILE_NAME = f'inpatient_models_bundle_{MODEL_SOURCE_TAG}_{MODEL_YEAR_TAG}_{BUNDLE_SUFFIX}{EXCLUSION_FILE_TAG}.pkl'\n\n# Save the bundle locally using pickle.\nwith open(BUNDLE_FILE_NAME, 'wb') as f:\n pickle.dump(inpatient_models_bundle, f)\nprint(f\"Models bundled and saved locally to: {BUNDLE_FILE_NAME}\")\n\n# Upload the local bundle file to the specified Snowflake stage.\nput_result = session.file.put(BUNDLE_FILE_NAME, SNOWFLAKE_STAGE_NAME, overwrite=True)\nif put_result[0].status == 'UPLOADED':\n print(f\"Model bundle successfully uploaded to Snowflake stage: {SNOWFLAKE_STAGE_NAME}\")\nelse:\n print(f\"Error uploading model bundle. Status: {put_result[0].status}, Message: {put_result[0].message}\")\n\nfile_size_mb = os.path.getsize(BUNDLE_FILE_NAME) / (1024 * 1024)\nprint(f\"Saved local bundle file size: {file_size_mb:.2f} MB\")\n\nprint(f\"\\n✅ Script finished ({'FAST MODE' if FAST_MODE else 'FULL MODE'}).\")",
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+ }
feature_fill_rate_inpatient.csv ADDED
The diff for this file is too large to render. See raw diff
 
inpatient_feature_importance.csv ADDED
The diff for this file is too large to render. See raw diff
 
inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:afb9b33cac24e91958adfc4f52ed3a1fd617a51c12bbf9d8be2c37eab4eb979c
+ size 1808823
inpatient_models_eval_metrics.csv ADDED
@@ -0,0 +1,13 @@
+ MODEL_RUN_ID,MODEL_NAME,TARGET_NAME,R2,MAE,MSE,PRED_RATIO,MAE_PERCENT,AUC_ROC,AUC_PR,LOG_LOSS,BRIER_SCORE,ACCURACY,AVG_Y_PRED,AVG_Y_TRUE,MODEL_SOURCE,MODEL_TYPE,MODEL_YEAR,EVAL_TS
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_transfer/other facility,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:12:05.239
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_snf,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:12:01.767
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_other,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:58.290
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_ipt rehab,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:55.442
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_hospice,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:52.653
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_home health,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:49.518
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_home,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:46.588
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_expired,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:43.545
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Overall_FeatureSelected,discharge_location,,,,,,,,1.302877,0.077095,0.521647,,,medicare_lds,Multiclass_Cal_Overall,2023,2025-06-18 21:11:40.194
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Readmission_Calibrated_FeatureSelected,readmission_numerator_Probability,,,,,,0.74825,0.334959,0.380948,0.117589,,0.157001,0.15629,medicare_lds,Binary_Classification_Probability_Calibrated,2023,2025-06-18 20:48:33.051
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Readmission_Base_Uncalibrated_FeatureSelected,readmission_numerator_Probability,,,,,,0.748529,0.341185,0.594258,0.20216,,0.428988,0.15629,medicare_lds,Binary_Classification_Probability_Uncalibrated,2023,2025-06-18 20:48:29.072
+ 03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_LOS_FeatureSelected,length_of_stay,0.245662,2.724616,23.875849,0.970435,54.179359,,,,,,4.880204201,5.028881,medicare_lds,Regression,2023,2025-06-18 20:46:30.898
predict inpatient.ipynb ADDED
@@ -0,0 +1,385 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3775908f-ca36-4846-8f38-5adca39217f2",
+ "metadata": {
+ "language": "python",
+ "name": "cell1"
+ },
+ "outputs": [],
+ "source": [
+ "\"\"\"\n",
+ "Snowflake Inpatient Prediction and Evaluation Script (Self-Contained)\n",
+ "\n",
+ "This script is designed to run entirely within a Snowflake environment with NO\n",
+ "external network access. It performs the following operations:\n",
+ "\n",
+ "1. Connects to an active Snowflake session.\n",
+ "2. Loads a pre-trained model bundle and reference tables from a Snowflake stage.\n",
+ "3. Loads new inpatient data from a Snowflake table for prediction.\n",
+ "4. Calculates and compares the feature fill rates of the input data against the training data.\n",
+ "5. Performs data preprocessing, including one-hot encoding.\n",
+ "6. Generates predictions for Length of Stay, Readmission, and Discharge Location.\n",
+ "7. Calculates evaluation metrics by comparing predictions to actual outcomes.\n",
+ "8. Saves predictions, metrics, and diagnostic tables back to Snowflake.\n",
+ "\"\"\"\n",
+ "\n",
+ "import os\n",
+ "import gzip\n",
+ "import pickle\n",
+ "import uuid\n",
+ "from datetime import datetime\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from snowflake.snowpark.context import get_active_session\n",
+ "from snowflake.snowpark.session import Session\n",
+ "from snowflake.snowpark.exceptions import SnowparkClientException\n",
+ "\n",
+ "from sklearn.metrics import (\n",
+ " r2_score, mean_absolute_error, mean_squared_error, accuracy_score,\n",
+ " roc_auc_score, log_loss, brier_score_loss, average_precision_score\n",
+ ")\n",
+ "from snowflake.snowpark.types import (\n",
+ " StructType, StructField, StringType, TimestampType, FloatType, LongType\n",
+ ")\n",
+ "\n",
+ "# =============================================================================\n",
+ "# 0. CONFIGURATION\n",
+ "# =============================================================================\n",
+ "\n",
+ "# --- Snowflake Environment Settings ---\n",
+ "DATABASE = \"CMS_SYNTHETIC\"\n",
+ "SCHEMA = \"BENCHMARKS\"\n",
+ "STAGE_DATABASE = \"CMS_SYNTHETIC\"\n",
+ "\n",
+ "# --- Input & Output Table Names ---\n",
+ "INPUT_TABLE_NAME = \"BENCHMARKS_INPATIENT_INPUT\"\n",
+ "OUTPUT_PREDICTIONS_TABLE_NAME = \"INPATIENT_PREDICTIONS\"\n",
+ "OUTPUT_METRICS_TABLE_NAME = \"INPATIENT_EVALUATION_METRICS\"\n",
+ "OUTPUT_FILL_RATE_COMPARISON_TABLE_NAME = \"INPATIENT_FILL_RATE_COMPARISON\"\n",
+ "\n",
+ "# --- Model & Artifact Loading Configuration ---\n",
+ "MODEL_STAGE_NAME = f\"@{STAGE_DATABASE}.{SCHEMA}.BENCHMARK_STAGE\"\n",
+ "MODEL_FILE_NAME_IN_STAGE = \"inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz\"\n",
+ "FILL_RATE_FILE_NAME_IN_STAGE = \"feature_fill_rate_inpatient.csv.gz\"\n",
+ "\n",
+ "# --- Local Directory for Artifacts ---\n",
+ "# This directory is used within the Snowflake virtual environment to store downloaded files.\n",
+ "LOCAL_ARTIFACT_DIR = \"/tmp/appRoot\"\n",
+ "\n",
+ "# --- Run Configuration ---\n",
+ "ROW_LIMIT = None # Set to an integer to limit input rows for testing, or None for all rows.\n",
+ "RUN_ID = str(uuid.uuid4())\n",
+ "\n",
+ "# --- Feature Configuration ---\n",
+ "# Set to True to calculate and save a comparison of feature fill rates\n",
+ "# between the input data and the original training data. This requires\n",
+ "# the fill rate file to be available in the stage. Set to False to skip.\n",
+ "ENABLE_FILL_RATE_COMPARISON = True\n",
+ "\n",
+ "# --- Construct Full Object Names ---\n",
+ "FULL_INPUT_TABLE = f\"{INPUT_TABLE_NAME}\"\n",
+ "FULL_OUTPUT_PREDICTIONS_TABLE = f\"{OUTPUT_PREDICTIONS_TABLE_NAME}\"\n",
+ "FULL_METRICS_TABLE = f\"{OUTPUT_METRICS_TABLE_NAME}\"\n",
+ "FULL_OUTPUT_FILL_RATE_COMPARISON_TABLE = f\"{OUTPUT_FILL_RATE_COMPARISON_TABLE_NAME}\"\n",
+ "MODEL_STAGE_PATH = f\"{MODEL_STAGE_NAME}/{MODEL_FILE_NAME_IN_STAGE}\"\n",
+ "FILL_RATE_STAGE_PATH = f\"{MODEL_STAGE_NAME}/{FILL_RATE_FILE_NAME_IN_STAGE}\"\n",
+ "\n",
+ "\n",
+ "# =============================================================================\n",
+ "# 1. UTILITY FUNCTIONS\n",
+ "# =============================================================================\n",
+ "\n",
+ "def calculate_regression_metrics(y_true, y_pred):\n",
+ " \"\"\"Calculates a dictionary of standard regression metrics.\"\"\"\n",
+ " y_true_np, y_pred_np = np.array(y_true), np.array(y_pred)\n",
+ " sum_y_true, mean_y_true = np.sum(y_true_np), np.mean(y_true_np)\n",
+ " pred_ratio = np.sum(y_pred_np) / sum_y_true if sum_y_true != 0 else np.nan\n",
+ " mae_percent = (mean_absolute_error(y_true_np, y_pred_np) / mean_y_true) * 100 if mean_y_true != 0 else np.nan\n",
+ " return {\n",
+ " 'R2': r2_score(y_true_np, y_pred_np), 'MAE': mean_absolute_error(y_true_np, y_pred_np),\n",
+ " 'MSE': mean_squared_error(y_true_np, y_pred_np), 'PRED_RATIO': pred_ratio, 'MAE_PERCENT': mae_percent,\n",
+ " 'AVG_Y_PRED': np.mean(y_pred_np), 'AVG_Y_TRUE': mean_y_true\n",
+ " }\n",
+ "\n",
+ "def calculate_binary_classification_proba_metrics(y_true, y_pred_proba):\n",
+ " \"\"\"Calculates a dictionary of standard binary classification metrics from probabilities.\"\"\"\n",
+ " y_true_np, y_pred_proba_np = np.array(y_true), np.array(y_pred_proba)\n",
+ " is_multiclass = len(np.unique(y_true_np)) > 1\n",
+ " auc_roc = roc_auc_score(y_true_np, y_pred_proba_np) if is_multiclass else np.nan\n",
+ " auc_pr = average_precision_score(y_true_np, y_pred_proba_np) if is_multiclass else np.nan\n",
+ " return {\n",
+ " 'AUC_ROC': auc_roc, 'AUC_PR': auc_pr, 'LOG_LOSS': log_loss(y_true_np, y_pred_proba_np),\n",
+ " 'BRIER_SCORE': brier_score_loss(y_true_np, y_pred_proba_np),\n",
+ " 'AVG_Y_PRED_PROBA': np.mean(y_pred_proba_np), 'AVG_Y_TRUE': np.mean(y_true_np)\n",
+ " }\n",
+ "\n",
+ "def calculate_multiclass_classification_metrics(y_true_encoded, y_pred_labels, y_pred_proba, le_classes):\n",
+ " \"\"\"Calculates a dictionary of standard multi-class classification metrics.\"\"\"\n",
+ " num_samples, num_classes = len(y_true_encoded), len(le_classes)\n",
+ " metrics = {\n",
+ " 'ACCURACY': accuracy_score(y_true_encoded, y_pred_labels),\n",
+ " 'LOG_LOSS': log_loss(y_true_encoded, y_pred_proba, labels=np.arange(num_classes))\n",
+ " }\n",
+ " per_class_details, all_brier_scores = [], []\n",
+ " if num_samples > 0 and num_classes > 0:\n",
+ " for i in range(num_classes):\n",
+ " class_name, true_class_binary = le_classes[i], (y_true_encoded == i).astype(int)\n",
+ " pred_proba_for_class, true_proportion_class = y_pred_proba[:, i], np.mean(true_class_binary)\n",
+ " brier_score_class = brier_score_loss(true_class_binary, pred_proba_for_class) if len(np.unique(true_class_binary)) > 1 else np.nan\n",
+ " all_brier_scores.append(brier_score_class)\n",
+ " per_class_details.append({\n",
+ " \"class_name\": class_name, \"avg_pred_proba\": np.mean(pred_proba_for_class),\n",
+ " \"true_proportion\": true_proportion_class,\n",
+ " \"proba_ratio\": np.mean(pred_proba_for_class) / true_proportion_class if true_proportion_class > 0 else np.nan,\n",
+ " \"brier_score\": brier_score_class\n",
+ " })\n",
+ " metrics['per_class_details'] = per_class_details\n",
+ " valid_brier_scores = [s for s in all_brier_scores if not np.isnan(s)]\n",
+ " metrics['BRIER_SCORE_MACRO_AVG'] = np.mean(valid_brier_scores) if valid_brier_scores else np.nan\n",
+ " return metrics\n",
+ "\n",
+ "def log_metrics_to_snowflake(session, run_id, model_id, data_source, year_nbr, model_target, metrics_dict, table_name):\n",
+ " \"\"\"Constructs a DataFrame from a metrics dictionary and appends it to a Snowflake table.\"\"\"\n",
+ " metrics_schema = StructType([\n",
+ " StructField(\"RUN_ID\", StringType(), nullable=False), StructField(\"MODEL_ID\", StringType(), nullable=True),\n",
+ " StructField(\"DATA_SOURCE\", StringType(), nullable=True), StructField(\"YEAR_NBR\", LongType(), nullable=True),\n",
+ " StructField(\"MODEL_TARGET\", StringType(), nullable=True), StructField(\"EVAL_TIMESTAMP\", TimestampType(), nullable=False),\n",
+ " StructField(\"N_SAMPLES\", LongType(), nullable=True), StructField(\"R2\", FloatType(), nullable=True),\n",
+ " StructField(\"MAE\", FloatType(), nullable=True), StructField(\"MSE\", FloatType(), nullable=True),\n",
+ " StructField(\"PRED_RATIO\", FloatType(), nullable=True), StructField(\"MAE_PERCENT\", FloatType(), nullable=True),\n",
+ " StructField(\"AUC_ROC\", FloatType(), nullable=True), StructField(\"AUC_PR\", FloatType(), nullable=True),\n",
+ " StructField(\"LOG_LOSS\", FloatType(), nullable=True), StructField(\"BRIER_SCORE\", FloatType(), nullable=True),\n",
+ " StructField(\"ACCURACY\", FloatType(), nullable=True), StructField(\"AVG_Y_PRED\", FloatType(), nullable=True),\n",
+ " StructField(\"AVG_Y_TRUE\", FloatType(), nullable=True)\n",
+ " ])\n",
+ " avg_y_pred = metrics_dict.get('AVG_Y_PRED', metrics_dict.get('AVG_Y_PRED_PROBA'))\n",
+ " payload = {\n",
+ " \"RUN_ID\": run_id, \"MODEL_ID\": model_id, \"DATA_SOURCE\": data_source,\n",
+ " \"YEAR_NBR\": int(year_nbr) if pd.notna(year_nbr) else None, \"MODEL_TARGET\": model_target,\n",
+ " \"EVAL_TIMESTAMP\": datetime.utcnow(), \"N_SAMPLES\": metrics_dict.get('n_samples'), \"R2\": metrics_dict.get('R2'),\n",
+ " \"MAE\": metrics_dict.get('MAE'), \"MSE\": metrics_dict.get('MSE'), \"PRED_RATIO\": metrics_dict.get('PRED_RATIO'),\n",
+ " \"MAE_PERCENT\": metrics_dict.get('MAE_PERCENT'), \"AUC_ROC\": metrics_dict.get('AUC_ROC'),\n",
+ " \"AUC_PR\": metrics_dict.get('AUC_PR'), \"LOG_LOSS\": metrics_dict.get('LOG_LOSS'),\n",
+ " \"BRIER_SCORE\": metrics_dict.get('BRIER_SCORE', metrics_dict.get('BRIER_SCORE_MACRO_AVG')),\n",
+ " \"ACCURACY\": metrics_dict.get('ACCURACY'), \"AVG_Y_PRED\": avg_y_pred, \"AVG_Y_TRUE\": metrics_dict.get('AVG_Y_TRUE')\n",
+ " }\n",
+ " dfm = pd.DataFrame([payload]).replace({np.nan: None, pd.NaT: None})\n",
+ " try:\n",
+ " column_order = [field.name for field in metrics_schema.fields]\n",
+ " dfm_reordered = dfm[column_order]\n",
+ " snowpark_df = session.create_dataframe(dfm_reordered.values.tolist(), schema=metrics_schema)\n",
+ " snowpark_df.write.mode(\"append\").save_as_table(table_name)\n",
+ " print(f\"✅ Logged metrics for {model_target} ({data_source}, {year_nbr}) to {table_name}.\")\n",
+ " except Exception as e:\n",
+ " print(f\"Error logging metrics for {model_target}: {e}\\nPayload: {dfm.to_dict('records')}\")\n",
+ "\n",
+ "# =============================================================================\n",
+ "# 2. MAIN EXECUTION\n",
+ "# =============================================================================\n",
+ "def main(session: Session):\n",
+ " print(f\"--- Starting Inpatient Prediction Pipeline ---\\nRun ID: {RUN_ID}\")\n",
+ " session.use_database(DATABASE)\n",
+ " session.use_schema(SCHEMA)\n",
+ " print(f\"Session context set to DATABASE: {DATABASE}, SCHEMA: {SCHEMA}\")\n",
+ "\n",
+ " # --- Stage 1: Load Model & Artifacts ---\n",
+ " print(f\"\\n--- Stage 1: Loading artifacts from stage to {LOCAL_ARTIFACT_DIR} ---\")\n",
+ " os.makedirs(LOCAL_ARTIFACT_DIR, exist_ok=True)\n",
+ " df_training_fill_rates = pd.DataFrame() # Initialize as empty DataFrame\n",
+ "\n",
+ " try:\n",
+ " # The model bundle is a required artifact for the script to run.\n",
+ " session.file.get(MODEL_STAGE_PATH, LOCAL_ARTIFACT_DIR)\n",
+ " local_model_path = os.path.join(LOCAL_ARTIFACT_DIR, MODEL_FILE_NAME_IN_STAGE)\n",
+ " with gzip.open(local_model_path, \"rb\") as f:\n",
+ " models_bundle = pickle.load(f)\n",
+ " print(\"✅ Model bundle loaded successfully.\")\n",
+ "\n",
+ " # The fill rate file is optional. It is loaded only if the feature is enabled.\n",
+ " if ENABLE_FILL_RATE_COMPARISON:\n",
+ " try:\n",
+ " session.file.get(FILL_RATE_STAGE_PATH, LOCAL_ARTIFACT_DIR)\n",
+ " local_fr_path = os.path.join(LOCAL_ARTIFACT_DIR, FILL_RATE_FILE_NAME_IN_STAGE)\n",
+ " with gzip.open(local_fr_path, \"rb\") as f:\n",
+ " df_training_fill_rates = pd.read_csv(f)\n",
+ " print(\"✅ Training fill rates loaded successfully.\")\n",
+ " except Exception as e:\n",
+ " print(f\"WARNING: Could not load fill rate file from '{FILL_RATE_STAGE_PATH}'. \"\n",
+ " f\"Fill rate comparison will be skipped. Error: {e}\")\n",
+ "\n",
+ " except Exception as e:\n",
+ " print(f\"CRITICAL ERROR: Failed to load the required model bundle from stage. Error: {e}\")\n",
+ " return\n",
+ "\n",
+ " los_model, readmission_model, discharge_model = models_bundle['los_model'], models_bundle['readmission_model'], models_bundle['discharge_model']\n",
+ " los_features, readmission_features, discharge_features = models_bundle['feature_columns_los'], models_bundle['feature_columns_readmission'], models_bundle['feature_columns_discharge']\n",
+ " le_discharge = models_bundle.get('le_discharge')\n",
+ " model_version = models_bundle.get('model_run_id', MODEL_FILE_NAME_IN_STAGE)\n",
+ " print(f\" - Using Model Version (from training run): {model_version}\")\n",
+ "\n",
+ " # --- Stage 2: Load New Data for Prediction ---\n",
+ " print(f\"\\n--- Stage 2: Loading data from table: {FULL_INPUT_TABLE} ---\")\n",
+ " try:\n",
+ " query = session.table(FULL_INPUT_TABLE)\n",
+ " if ROW_LIMIT:\n",
+ " query = query.limit(ROW_LIMIT)\n",
+ " df_new_data_pd = query.to_pandas()\n",
+ " print(f\"Loaded {len(df_new_data_pd)} rows for prediction.\")\n",
+ " if df_new_data_pd.empty:\n",
+ " print(\"WARNING: Input data is empty. Exiting script.\")\n",
+ " return\n",
+ " df_new_data_pd.columns = [col.upper() for col in df_new_data_pd.columns]\n",
+ " except Exception as e:\n",
+ " print(f\"CRITICAL ERROR: Failed to load data from {FULL_INPUT_TABLE}. Error: {e}\")\n",
+ " return\n",
+ " \n",
+ " # --- Stage 3: Preprocess New Data ---\n",
+ " print(\"\\n--- Stage 3: Preprocessing data ---\")\n",
+ " df_new_data_pd_lower = df_new_data_pd.copy()\n",
+ " df_new_data_pd_lower.columns = df_new_data_pd_lower.columns.str.lower()\n",
+ " categorical_cols = ['sex', 'state', 'race', 'ms_drg_code', 'ccsr_cat']\n",
+ " cols_to_encode = [col for col in categorical_cols if col in df_new_data_pd_lower.columns]\n",
+ " if cols_to_encode:\n",
+ " df_new_data_encoded = pd.get_dummies(df_new_data_pd_lower, columns=cols_to_encode, dummy_na=False)\n",
+ " else:\n",
+ " df_new_data_encoded = df_new_data_pd_lower.copy()\n",
+ " print(f\"Data preprocessed. Total features after encoding: {len(df_new_data_encoded.columns)}\")\n",
+ "\n",
+ " # --- Stage 4: Calculate & Compare Feature Fill Rates ---\n",
+ " print(\"\\n--- Stage 4: Comparing input data fill rate to training data fill rate ---\")\n",
+ " if ENABLE_FILL_RATE_COMPARISON:\n",
+ " # This block executes only if the feature is enabled in the configuration.\n",
+ " if not df_training_fill_rates.empty:\n",
+ " # Calculate the non-null rate for all columns.\n",
+ " input_fill_rate_series = (df_new_data_encoded.notna().sum() / len(df_new_data_encoded))\n",
+ "\n",
+ " # Identify binary/categorical features to calculate positive rate instead of non-null rate.\n",
+ " binary_prefixes = ('hcc_', 'cms_', 'cond_') + tuple(f'{col}_' for col in cols_to_encode)\n",
+ " binary_feature_names = [\n",
+ " col for col in df_new_data_encoded.columns \n",
+ " if col.lower().startswith(binary_prefixes)\n",
+ " ]\n",
+ " print(f\"Identified {len(binary_feature_names)} binary-like features for positive rate calculation.\")\n",
+ " \n",
+ " # For these binary columns, calculate the positive rate (mean) and update the series.\n",
+ " if binary_feature_names:\n",
+ " existing_binary_features = [c for c in binary_feature_names if c in df_new_data_encoded.columns]\n",
+ " if existing_binary_features:\n",
+ " positive_rates = df_new_data_encoded[existing_binary_features].mean()\n",
+ " input_fill_rate_series.update(positive_rates)\n",
+ "\n",
+ " df_input_fill_rates = input_fill_rate_series.reset_index()\n",
+ " df_input_fill_rates.columns = ['FEATURE', 'INPUT_FILL_RATE']\n",
+ "\n",
+ " # Prepare training data fill rates for comparison.\n",
+ " df_training_fill_rates.columns = df_training_fill_rates.columns.str.upper()\n",
+ " df_training_fill_rates = df_training_fill_rates.rename(columns={'FEATURE_NAME': 'FEATURE'})\n",
+ " if 'POSITIVE_RATE_PERCENT' in df_training_fill_rates.columns:\n",
+ " df_training_fill_rates['TRAINING_FILL_RATE'] = df_training_fill_rates['POSITIVE_RATE_PERCENT'] / 100.0\n",
+ "\n",
+ " # Merge input and training fill rates on the feature name.\n",
+ " df_training_fill_rates['FEATURE'] = df_training_fill_rates['FEATURE'].str.upper()\n",
+ " df_input_fill_rates['FEATURE'] = df_input_fill_rates['FEATURE'].str.upper()\n",
+ "\n",
+ " df_comparison = pd.merge(\n",
+ " df_training_fill_rates[['FEATURE', 'TRAINING_FILL_RATE']],\n",
+ " df_input_fill_rates, on='FEATURE', how='outer'\n",
+ " )\n",
+ " df_comparison['RUN_ID'] = RUN_ID\n",
+ " df_comparison['LAST_RUN'] = pd.Timestamp(datetime.utcnow())\n",
+ " df_comparison['FILL_RATE_DIFFERENCE'] = df_comparison['INPUT_FILL_RATE'] - df_comparison['TRAINING_FILL_RATE']\n",
+ " df_comparison = df_comparison[['RUN_ID', 'FEATURE', 'TRAINING_FILL_RATE', 'INPUT_FILL_RATE', 'FILL_RATE_DIFFERENCE', 'LAST_RUN']]\n",
+ " \n",
+ " # Save the comparison table to Snowflake.\n",
+ " session.write_pandas(df_comparison, FULL_OUTPUT_FILL_RATE_COMPARISON_TABLE, auto_create_table=True, overwrite=True)\n",
+ " print(f\"✅ Successfully saved fill rate comparison to {FULL_OUTPUT_FILL_RATE_COMPARISON_TABLE}\")\n",
+ " else:\n",
+ " print(\"WARNING: Training fill rate data not available. Skipping comparison.\")\n",
+ " else:\n",
+ " print(\"Fill rate comparison is disabled by configuration. Skipping.\")\n",
+ " \n",
+ " # --- Stage 5: Generate Predictions & Save ---\n",
+ " print(\"\\n--- Stage 5: Generating and saving predictions ---\")\n",
+ " predictions_df = pd.DataFrame({\n",
+ " \"ENCOUNTER_ID\": df_new_data_encoded['encounter_id'],\n",
+ " \"LENGTH_OF_STAY_PRED\": los_model.predict(df_new_data_encoded.reindex(columns=los_features, fill_value=0)),\n",
+ " \"READMISSION_PRED\": readmission_model.predict_proba(df_new_data_encoded.reindex(columns=readmission_features, fill_value=0))[:, 1],\n",
+ " \"LAST_RUN\": datetime.utcnow()\n",
+ " })\n",
+ " if le_discharge is not None:\n",
+ " discharge_probas = discharge_model.predict_proba(df_new_data_encoded.reindex(columns=discharge_features, fill_value=0))\n",
+ " for i, class_label in enumerate(le_discharge.classes_):\n",
+ " col_name = f\"DISCHARGE_PRED_PROBA_{class_label.upper().replace(' ', '_').replace('/', '_')}\"\n",
+ " predictions_df[col_name] = discharge_probas[:, i]\n",
+ " session.write_pandas(predictions_df, FULL_OUTPUT_PREDICTIONS_TABLE, auto_create_table=True, overwrite=True)\n",
+ " print(f\"✅ Successfully saved {len(predictions_df)} predictions to {FULL_OUTPUT_PREDICTIONS_TABLE}\")\n",
+ "\n",
+ " # --- Stage 6: Calculate and Save Evaluation Metrics ---\n",
+ " print(\"\\n--- Stage 6: Calculating and saving evaluation metrics ---\")\n",
+ " actual_cols_to_get = ['ENCOUNTER_ID', 'DATA_SOURCE', 'YEAR_NBR', 'LENGTH_OF_STAY', 'READMISSION_NUMERATOR', 'READMISSION_DENOMINATOR', 'DISCHARGE_LOCATION']\n",
+ " actual_cols = [col for col in actual_cols_to_get if col in df_new_data_pd.columns]\n",
+ " eval_df = pd.merge(predictions_df, df_new_data_pd[actual_cols], on=\"ENCOUNTER_ID\", how=\"left\")\n",
+ " eval_df.columns = eval_df.columns.str.lower()\n",
+ " \n",
+ " ACTUAL_LOS_COL_LOWER, ACTUAL_READMISSION_NUM_LOWER, ACTUAL_READMISSION_DENOM_LOWER, ACTUAL_DISCHARGE_COL_LOWER = \"length_of_stay\", \"readmission_numerator\", \"readmission_denominator\", \"discharge_location\"\n",
+ " groups_to_iterate = eval_df.groupby(['data_source', 'year_nbr'], dropna=False) if 'data_source' in eval_df.columns and 'year_nbr' in eval_df.columns else [((\"Overall\", -1), eval_df)]\n",
+ "\n",
+ " for group_key, group_df in groups_to_iterate:\n",
+ " data_source_val, year_nbr_val = group_key\n",
+ " print(f\"\\n--- Calculating metrics for group: {data_source_val}, {year_nbr_val} ---\")\n",
+ " if ACTUAL_LOS_COL_LOWER in group_df.columns and not group_df[ACTUAL_LOS_COL_LOWER].isnull().all():\n",
+ " metrics = calculate_regression_metrics(group_df[ACTUAL_LOS_COL_LOWER], group_df[\"length_of_stay_pred\"])\n",
+ " metrics['n_samples'] = len(group_df)\n",
+ " log_metrics_to_snowflake(session, RUN_ID, model_version, data_source_val, year_nbr_val, \"LENGTH_OF_STAY\", metrics, FULL_METRICS_TABLE)\n",
+ " if ACTUAL_READMISSION_NUM_LOWER in group_df.columns and ACTUAL_READMISSION_DENOM_LOWER in group_df.columns:\n",
+ " readmission_eval_df = group_df[group_df[ACTUAL_READMISSION_DENOM_LOWER] == 1].copy()\n",
+ " if not readmission_eval_df.empty and not readmission_eval_df[ACTUAL_READMISSION_NUM_LOWER].isnull().all():\n",
+ " metrics = calculate_binary_classification_proba_metrics(readmission_eval_df[ACTUAL_READMISSION_NUM_LOWER], readmission_eval_df[\"readmission_pred\"])\n",
+ " metrics['n_samples'] = len(readmission_eval_df)\n",
+ " log_metrics_to_snowflake(session, RUN_ID, model_version, data_source_val, year_nbr_val, \"READMISSION\", metrics, FULL_METRICS_TABLE)\n",
+ " if ACTUAL_DISCHARGE_COL_LOWER in group_df.columns and le_discharge is not None:\n",
+ " group_input = df_new_data_encoded.loc[group_df.index]\n",
+ " y_true_labels = group_df[ACTUAL_DISCHARGE_COL_LOWER]\n",
+ " known_mask = y_true_labels.isin(le_discharge.classes_)\n",
+ " y_true_enc = le_discharge.transform(y_true_labels[known_mask])\n",
+ " group_probas = discharge_model.predict_proba(group_input[known_mask].reindex(columns=discharge_features, fill_value=0))\n",
+ " group_preds_enc = discharge_model.predict(group_input[known_mask].reindex(columns=discharge_features, fill_value=0))\n",
+ " metrics = calculate_multiclass_classification_metrics(y_true_enc, group_preds_enc, group_probas, le_discharge.classes_)\n",
+ " metrics['n_samples'] = len(y_true_enc)\n",
+ " log_metrics_to_snowflake(session, RUN_ID, model_version, data_source_val, year_nbr_val, \"DISCHARGE_LOCATION_OVERALL\", metrics, FULL_METRICS_TABLE)\n",
+ " for detail in metrics.get('per_class_details', []):\n",
+ " log_metrics_to_snowflake(session, RUN_ID, model_version, data_source_val, year_nbr_val, f\"DISCHARGE_LOCATION_Class_{detail['class_name']}\", detail, FULL_METRICS_TABLE)\n",
+ " print(\"\\n✅ Script finished.\")\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " try:\n",
+ " snowpark_session = get_active_session()\n",
+ " print(\"Successfully retrieved active Snowpark session.\")\n",
+ " except SnowparkClientException:\n",
+ " print(\"No active session. Creating a new session from local credentials...\")\n",
+ " snowpark_session = Session.builder.create()\n",
+ " main(snowpark_session)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Streamlit Notebook",
+ "name": "streamlit"
+ },
+ "lastEditStatus": {
+ "authorEmail": "[email protected]",
+ "authorId": "374530764978",
+ "authorName": "BRAD",
+ "lastEditTime": 1750882358404,
+ "notebookId": "7e7pzs6ti4k6chxub4nt",
+ "sessionId": "56467df4-1029-4269-83ed-9a238cb180f6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }