bradmontierth committed · 8cd7f61
Initial commit

Files changed:
- .gitattributes +35 -0
- README.md +205 -0
- Train Tuva Concurrent Inpatient Models.ipynb +31 -0
- feature_fill_rate_inpatient.csv +0 -0
- inpatient_feature_importance.csv +0 -0
- inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz +3 -0
- inpatient_models_eval_metrics.csv +13 -0
- predict inpatient.ipynb +385 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,205 @@
---
license: mit
tags:
- healthcare
- medicare
- xgboost
- logistic-regression
- length-of-stay
- readmission
- discharge-prediction
- classification
- regression
datasets:
- cms-lds

---

# Medicare Inpatient Outcome Prediction Models

**Model ID:** Medicare LDS 2023 Inpatient Model Bundle
**Model Types:** XGBoost Regressor & Calibrated Multinomial Logistic Regression
**Dataset:** 2023 CMS Limited Data Set (LDS)
**Target Level:** Inpatient Encounter

---

## What the Model Predicts

This bundle contains three distinct models that predict key outcomes for a given inpatient hospital stay (an illustrative prediction row follows the list):

- **Length of Stay (Regression):** Predicts the total number of days for the inpatient stay.
- **Readmission Probability (Binary Classification):** Predicts the probability (from 0.0 to 1.0) that the patient will be readmitted to a hospital within 30 days of discharge.
- **Discharge Location Probability (Multiclass Classification):** Predicts the probability for each possible discharge location (e.g., Home, Skilled Nursing Facility, Hospice).
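
For orientation, a single scored encounter might look like the row below. This is a fabricated example; the column names mirror the prediction output written by `predict inpatient.ipynb`, which creates one `DISCHARGE_PRED_PROBA_*` column per discharge class.

```python
import pandas as pd

# Fabricated example row -- values are made up; column names follow the
# prediction output of `predict inpatient.ipynb`.
example = pd.DataFrame([{
    "ENCOUNTER_ID": "enc_0001",
    "LENGTH_OF_STAY_PRED": 4.9,         # predicted days (regression)
    "READMISSION_PRED": 0.16,           # predicted 30-day readmission probability
    "DISCHARGE_PRED_PROBA_HOME": 0.41,  # one probability column per class
    "DISCHARGE_PRED_PROBA_SNF": 0.22,
}])
print(example)
```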

---

## Intended Use

This model bundle is designed to support a variety of clinical and operational workflows:

- **Discharge Planning & Care Management:** Identify high-risk patients who may need additional support or a specific type of post-acute care to prevent readmission.
- **Resource Planning:** Forecast bed-days and resource utilization based on predicted length of stay.
- **Actuarial Analysis:** Inform risk stratification and cost estimation models.
- **Benchmarking:** Compare observed outcomes against predicted risks for a given patient population.
- **Healthcare Research:** Analyze drivers of inpatient outcomes.

> **Note on Prediction Type:** The models are trained for **concurrent prediction** — they use clinical and demographic data available during the inpatient stay to predict outcomes related to that same stay.

---

## Model Performance

> These metrics reflect performance on a **20% test set** held out from the 2023 CMS LDS data. All values were calculated on unseen data and represent model generalization performance. A sketch of the hold-out scheme follows.
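
As a sketch of how such a hold-out is produced: the training notebook splits the data 80/20 and then carves a calibration set out of the training portion (60% base train / 20% calibration / 20% test, `random_state=42`, stratified where applicable). The data below is synthetic.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 5))     # placeholder feature matrix
y = rng.integers(0, 2, size=1000)  # placeholder binary target

# 80/20 split, then 75/25 of the remainder -> 60% base train,
# 20% calibration, 20% test, mirroring the training notebook.
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)
X_train_base, X_calib, y_train_base, y_calib = train_test_split(
    X_train_full, y_train_full, test_size=0.25, random_state=42, stratify=y_train_full)
print(len(X_train_base), len(X_calib), len(X_test))  # 600 200 200
```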

### Model 1: Length of Stay (XGBoost Regressor)

| Target           | R²   | MAE (days) |
|------------------|------|------------|
| `length_of_stay` | 0.25 | 2.72       |

### Model 2: Readmission Probability (Calibrated Logistic Regression)

| Target                    | AUC ROC | Brier Score |
|---------------------------|---------|-------------|
| `readmission_probability` | 0.7483  | 0.1176      |

- **AUC ROC:** Measures the model's ability to distinguish between patients who will and will not be readmitted (higher is better).
- **Brier Score:** Measures the accuracy of the predicted probabilities (lower is better). A toy computation of both metrics follows.
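
Both metrics come directly from scikit-learn; a minimal computation on fabricated labels and probabilities:

```python
import numpy as np
from sklearn.metrics import roc_auc_score, brier_score_loss

y_true = np.array([0, 0, 1, 0, 1, 1, 0, 0])                   # fabricated outcomes
y_prob = np.array([0.1, 0.3, 0.7, 0.2, 0.5, 0.9, 0.4, 0.1])   # fabricated probabilities

print(f"AUC ROC:     {roc_auc_score(y_true, y_prob):.4f}")    # ranking quality
print(f"Brier score: {brier_score_loss(y_true, y_prob):.4f}") # probability accuracy
```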

### Model 3: Discharge Location (Calibrated Logistic Regression)

| Target               | Accuracy | Brier Score (Macro Avg) |
|----------------------|----------|-------------------------|
| `discharge_location` | 0.5216   | 0.0771                  |

- **Accuracy:** The overall percentage of times the model predicted the correct discharge location.
- **Brier Score (Macro Avg):** The average Brier Score across all possible discharge location classes (lower is better), as sketched below.
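
A minimal sketch of the macro-average construction (the evaluation code in this repo computes a one-vs-rest Brier Score per class and averages them); the data here is fabricated:

```python
import numpy as np
from sklearn.metrics import brier_score_loss

y_true = np.array([0, 1, 0, 2, 1])      # fabricated encoded class labels
y_proba = np.array([[0.7, 0.2, 0.1],    # fabricated per-class probabilities
                    [0.2, 0.6, 0.2],
                    [0.5, 0.4, 0.1],
                    [0.1, 0.3, 0.6],
                    [0.3, 0.5, 0.2]])

# One-vs-rest Brier Score per class, then the unweighted mean.
per_class = [brier_score_loss((y_true == i).astype(int), y_proba[:, i])
             for i in range(y_proba.shape[1])]
print(f"Brier Score (macro avg): {np.mean(per_class):.4f}")
```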

---

## Files Included

- `inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz` — A gzip-compressed pickle file containing the trained models, feature lists, and encoders (a loading sketch follows this list). The filename may vary based on training parameters (e.g., `_fs` indicates feature selection was used). The bundle includes:
  - `los_model` (XGBoost)
  - `readmission_model` (Calibrated Logistic Regression)
  - `discharge_model` (Calibrated Logistic Regression)
  - `feature_columns_*` (specific feature lists for each model)
  - `le_discharge` (label encoder for discharge location)
- `Train Tuva Concurrent Inpatient Models.ipynb` — The notebook used for training, feature selection, and evaluation on Snowflake.
- `predict inpatient.ipynb` — An example prediction notebook for running the bundle on new data.
- `feature_fill_rate_inpatient.csv` — A diagnostic file detailing the prevalence of each feature in the training dataset.
- `inpatient_feature_importance.csv` — A file containing the calculated importance of each feature for each of the three models.
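
A minimal loading sketch, mirroring the loading code in `predict inpatient.ipynb` (unpickling assumes scikit-learn/XGBoost versions compatible with the ones used for training):

```python
import gzip
import pickle

# Open the gzipped bundle and pull out the named objects listed above.
with gzip.open("inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz", "rb") as f:
    bundle = pickle.load(f)

los_model = bundle["los_model"]
readmission_model = bundle["readmission_model"]
discharge_model = bundle["discharge_model"]
le_discharge = bundle["le_discharge"]
print(sorted(bundle.keys()))
```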

---

## Understanding Model Artifacts

This repository includes two key CSV files that provide insight into the model's training data and internal logic. These are generated by the training notebook, which also populates corresponding tables in Snowflake for easier querying (`FEATURE_FREQUENCY_STATS_INPATIENT` and `MODEL_FEATURE_IMPORTANCE_INPATIENT`).

### Feature Fill Rates (`feature_fill_rate_inpatient.csv`)

This file is a diagnostic tool for understanding the input data used to train the models. It helps you check for data drift or data quality issues.

| Column | Description |
|---|---|
| `FEATURE_NAME` | The name of the input feature (e.g., `age_at_admit`, `cond_hypertension`). |
| `POSITIVE_COUNT` | The number of records in the training set where this feature was present (non-zero). |
| `TOTAL_ROWS` | The total number of records in the training set. |
| `POSITIVE_RATE_PERCENT` | The prevalence or "fill rate" of the feature (`POSITIVE_COUNT` / `TOTAL_ROWS`, expressed as a percentage). |

**How to Use:** Compare the `POSITIVE_RATE_PERCENT` from this file with the rates from your own prediction input data, as sketched below. Significant discrepancies can point to data pipeline issues and may explain poor model performance.
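
A sketch of that comparison, assuming you have already assembled a one-hot encoded input frame (`my_input` below is a fabricated stand-in):

```python
import pandas as pd

train_rates = pd.read_csv("feature_fill_rate_inpatient.csv")
train_rates["TRAINING_FILL_RATE"] = train_rates["POSITIVE_RATE_PERCENT"] / 100.0

my_input = pd.DataFrame({"cond_hypertension": [1, 0, 1, 1]})  # stand-in scoring data
input_rates = my_input.mean().reset_index()
input_rates.columns = ["FEATURE_NAME", "INPUT_FILL_RATE"]

comparison = train_rates[["FEATURE_NAME", "TRAINING_FILL_RATE"]].merge(
    input_rates, on="FEATURE_NAME", how="outer")
comparison["DIFFERENCE"] = comparison["INPUT_FILL_RATE"] - comparison["TRAINING_FILL_RATE"]
print(comparison.sort_values("DIFFERENCE", ascending=False).head())
```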

### Feature Importances (`inpatient_feature_importance.csv`)

This file provides model explainability by showing which features are most influential for each of the three models.

| Column | Description |
|---|---|
| `MODEL_NAME` | Identifies the model (e.g., `Inpatient_LOS_FeatureSelected`). |
| `FEATURE_NAME` | The name of the input feature. |
| `IMPORTANCE_VALUE` | A numeric score indicating the feature's influence. Higher is more important. |
| `IMPORTANCE_RANK` | The rank of the feature's importance for that specific model (1 is most important). |

**How to Use:** Use this file to understand the key drivers behind the model's predictions. For example, you can filter by `MODEL_NAME` for the readmission model and sort by `IMPORTANCE_RANK` to see what most influences readmission risk, as shown below. This is useful for clinical validation and debugging.
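
For example (the `MODEL_NAME` match below is a pattern, not an exact value; check the distinct values in the file):

```python
import pandas as pd

imp = pd.read_csv("inpatient_feature_importance.csv")
readmit = imp[imp["MODEL_NAME"].str.contains("Readmission", case=False)]
top10 = readmit.sort_values("IMPORTANCE_RANK").head(10)
print(top10[["FEATURE_NAME", "IMPORTANCE_VALUE", "IMPORTANCE_RANK"]])
```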

---

## Quick Start: End-to-End Workflow

This section provides high-level instructions for running a model with the Tuva Project. The workflow involves preparing benchmark data using dbt, running a Python prediction script, and optionally ingesting the results back into dbt for analysis.

### 1. Configure Your dbt Project

You need to enable the correct variables in your `dbt_project.yml` file to control the workflow.

#### A. Enable Benchmark Marts

These two variables control which parts of the Tuva Project are active. They are `false` by default.

```yaml
# in dbt_project.yml
vars:
  benchmarks_train: true
  benchmarks_already_created: true
```

- `benchmarks_train`: Set to `true` to build the datasets that the ML models will use for making predictions.
- `benchmarks_already_created`: Set to `true` to ingest model predictions back into the project as a new dbt source.

#### B. (Optional) Set Prediction Source Locations

If you plan to bring predictions back into dbt for analysis, you must define where dbt can find the prediction data.

```yaml
# in dbt_project.yml
vars:
  predictions_person_year: "{{ source('benchmark_output', 'person_year') }}"
  predictions_inpatient: "{{ source('benchmark_output', 'inpatient') }}"
```

#### C. Configure `sources.yml`

Ensure your `sources.yml` file includes a definition for the source you referenced above (e.g., `benchmark_output`) that points to the database and schema where your model's prediction outputs are stored; an illustrative example follows.
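
A minimal `sources.yml` entry might look like this (the database and schema values are placeholders; only the source name and table names must line up with the vars above):

```yaml
# in sources.yml
version: 2

sources:
  - name: benchmark_output
    database: your_database   # placeholder: where the prediction script writes
    schema: your_schema       # placeholder
    tables:
      - name: person_year
      - name: inpatient
```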

---

### 2. The 3-Step Run Process

This workflow can be managed by any orchestration tool (e.g., Airflow, Prefect, Fabric Notebooks) or run manually from the command line.

#### Step 1: Generate the Training & Benchmarking Data

Run the Tuva Project with `benchmarks_train` enabled. This creates the input data required by the ML model.

```bash
dbt build --vars '{benchmarks_train: true}'
```

To run only the benchmark mart:

```bash
dbt build --select tag:benchmarks_train --vars '{benchmarks_train: true}'
```

#### Step 2: Run the Prediction Python Code

Execute the prediction notebook (`predict inpatient.ipynb`) to generate predictions. The notebook reads the data created in Step 1 and writes the prediction outputs to a persistent location (e.g., a table in your data warehouse). A minimal scoring sketch follows.

*We have provided example Snowflake Notebook code within each model's repository that was used in Tuva's environment.*
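
A minimal scoring sketch under stated assumptions: the bundle file is available locally, and the parquet path below is a hypothetical export of the `BENCHMARKS_INPATIENT_INPUT` table built in Step 1. `reindex` aligns the input to each model's training feature list and fills missing columns with 0, as the notebook does.

```python
import gzip
import pickle
import pandas as pd

with gzip.open("inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz", "rb") as f:
    bundle = pickle.load(f)

# Hypothetical local export of the one-hot encoded benchmark input.
df = pd.read_parquet("benchmarks_inpatient_input.parquet")

los_pred = bundle["los_model"].predict(
    df.reindex(columns=bundle["feature_columns_los"], fill_value=0))
readmit_prob = bundle["readmission_model"].predict_proba(
    df.reindex(columns=bundle["feature_columns_readmission"], fill_value=0))[:, 1]
discharge_probs = bundle["discharge_model"].predict_proba(
    df.reindex(columns=bundle["feature_columns_discharge"], fill_value=0))
```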

#### Step 3: (Optional) Bring Predictions Back into the Tuva Project

To bring the predictions back into the Tuva Project for analysis, run dbt again with `benchmarks_already_created` enabled. This populates the analytics marts.

```bash
dbt build --vars '{benchmarks_already_created: true, benchmarks_train: false}'
```

To run only the analysis models:

```bash
dbt build --select tag:benchmarks_analysis --vars '{benchmarks_already_created: true, benchmarks_train: false}'
```

---
Train Tuva Concurrent Inpatient Models.ipynb
ADDED
@@ -0,0 +1,31 @@
{
"metadata": {
"kernelspec": {
"display_name": "Streamlit Notebook",
"name": "streamlit"
},
"lastEditStatus": {
"notebookId": "6rovstl42ft2p5id6gwo",
"authorId": "374530764978",
"authorName": "BRAD",
"authorEmail": "[email protected]",
"sessionId": "65561efa-4d18-4072-8f4d-10240cb902ba",
"lastEditTime": 1750870004305
}
},
"nbformat_minor": 5,
"nbformat": 4,
"cells": [
{
"cell_type": "code",
"id": "3775908f-ca36-4846-8f38-5adca39217f2",
"metadata": {
"language": "python",
"name": "cell1"
},
"source": "0]}\")\n \n base_readmit_model = LogisticRegression(class_weight='balanced', random_state=42, max_iter=1000, solver='liblinear')\n \n # Determine the feature set to use.\n if FAST_MODE:\n print(\"\\n[FAST MODE] Skipping feature selection. Using all available features.\")\n best_readmission_features = X_train_base.columns.tolist()\n else:\n best_readmission_features = find_best_feature_subset(\n model=base_readmit_model, X_train=X_train_base, y_train=y_train_base, X_val=X_calib_read, y_val=y_calib_read,\n scoring_func=roc_auc_score, higher_is_better=True, model_name=\"Readmission (Logistic Regression)\"\n )\n\n print(f\"\\nTraining final Readmission model pipeline using {len(best_readmission_features)} features...\")\n base_model_for_calib = clone(base_readmit_model)\n base_model_for_calib.fit(X_train_base[best_readmission_features], y_train_base)\n \n # Log feature importances from the base (uncalibrated) model.\n uncal_read_model_name = f\"Inpatient_Readmission_Base_Uncalibrated_{MODEL_NAME_SUFFIX}{EXCLUSION_SUFFIX}\"\n log_feature_importances_to_snowflake(session, base_model_for_calib, best_readmission_features, MODEL_RUN_ID, uncal_read_model_name, TARGET_READMISSION, FEATURE_IMPORTANCE_TABLE_NAME)\n \n # Evaluate and log metrics for the uncalibrated model for comparison.\n y_pred_proba_uncal = base_model_for_calib.predict_proba(X_test_read[best_readmission_features])[:, 1]\n uncalibrated_metrics = calculate_binary_classification_proba_metrics(y_test_read, y_pred_proba_uncal)\n log_model_metrics_to_snowflake(session, MODEL_RUN_ID, uncal_read_model_name, TARGET_READMISSION + \"_Probability\", uncalibrated_metrics, \"Binary_Uncalibrated\", METRICS_TABLE_NAME, MODEL_SOURCE_TAG, MODEL_YEAR_TAG)\n \n # Calibrate the model on the held-out calibration set.\n calibrated_readmission_model = CalibratedClassifierCV(base_model_for_calib, method='isotonic', cv='prefit')\n calibrated_readmission_model.fit(X_calib_read[best_readmission_features], y_calib_read)\n y_pred_proba_cal = calibrated_readmission_model.predict_proba(X_test_read[best_readmission_features])[:, 1]\n\n print(\"\\nCalibrated Readmission Model - Test Set Evaluation:\")\n calibrated_proba_metrics = calculate_binary_classification_proba_metrics(y_test_read, y_pred_proba_cal)\n for k, v in calibrated_proba_metrics.items(): print(f\" {k}: {v:.4f}\")\n \n cal_read_model_name = f\"Inpatient_Readmission_Calibrated_{MODEL_NAME_SUFFIX}{EXCLUSION_SUFFIX}\"\n log_model_metrics_to_snowflake(session, MODEL_RUN_ID, cal_read_model_name, TARGET_READMISSION + \"_Probability\", calibrated_proba_metrics, \"Binary_Calibrated\", METRICS_TABLE_NAME, MODEL_SOURCE_TAG, MODEL_YEAR_TAG)\n\n# --- 4.3 Model 3: Predicting Discharge Location (Multiclass Classification) ---\nprint(\"\\n\" + \"=\"*80)\nprint(\"--- Training Model 3: Calibrated Discharge Location ---\")\nprint(\"=\"*80)\nTARGET_DISCHARGE = 'discharge_location'\ncalibrated_discharge_model, le_discharge, best_discharge_features = None, None, None\n\nif TARGET_DISCHARGE not in df_pd.columns:\n print(f\"Error: Target column '{TARGET_DISCHARGE}' not found. 
Skipping Discharge Location model.\")\nelse:\n le_discharge = LabelEncoder()\n y_discharge_encoded = le_discharge.fit_transform(df_pd[TARGET_DISCHARGE])\n num_classes_discharge = len(le_discharge.classes_)\n print(f\"Discharge Location: {num_classes_discharge} classes found: {le_discharge.classes_}\")\n \n # Split data: 60% base train, 20% calibration, 20% test\n stratify_discharge = y_discharge_encoded if num_classes_discharge > 1 else None\n X_train_full_disc, X_test_disc, y_train_full_disc_enc, y_test_disc_enc = train_test_split(X, y_discharge_encoded, test_size=0.2, random_state=42, stratify=stratify_discharge)\n X_train_base_disc, X_calib_disc, y_train_base_disc_enc, y_calib_disc_enc = train_test_split(X_train_full_disc, y_train_full_disc_enc, test_size=0.25, random_state=42, stratify=y_train_full_disc_enc if num_classes_discharge > 1 else None)\n print(f\"Data split for discharge: Base train: {X_train_base_disc.shape[0]}, Calibration: {X_calib_disc.shape[0]}, Test: {X_test_disc.shape[0]}\")\n \n base_discharge_model = LogisticRegression(random_state=42, max_iter=1000, solver='lbfgs', multi_class='multinomial', class_weight='balanced')\n \n # Determine the feature set to use.\n if FAST_MODE:\n print(\"\\n[FAST MODE] Skipping feature selection. Using all available features.\")\n best_discharge_features = X_train_base_disc.columns.tolist()\n else:\n best_discharge_features = find_best_feature_subset(\n model=base_discharge_model, X_train=X_train_base_disc, y_train=y_train_base_disc_enc, X_val=X_calib_disc, y_val=y_calib_disc_enc,\n scoring_func=log_loss, higher_is_better=False, model_name=\"Discharge Location (Multinomial Regression)\"\n )\n\n print(f\"\\nTraining final Discharge Location model pipeline using {len(best_discharge_features)} features...\")\n base_model_for_calib_disc = clone(base_discharge_model)\n base_model_for_calib_disc.fit(X_train_base_disc[best_discharge_features], y_train_base_disc_enc)\n \n discharge_model_name = f\"Inpatient_Discharge_Cal_Overall_{MODEL_NAME_SUFFIX}{EXCLUSION_SUFFIX}\"\n log_feature_importances_to_snowflake(session, base_model_for_calib_disc, best_discharge_features, MODEL_RUN_ID, discharge_model_name, TARGET_DISCHARGE, FEATURE_IMPORTANCE_TABLE_NAME)\n \n # Calibrate the model. 
'sigmoid' is used for one-vs-rest calibration, suitable for multiclass.\n calibrated_discharge_model = CalibratedClassifierCV(base_model_for_calib_disc, method='sigmoid', cv='prefit')\n calibrated_discharge_model.fit(X_calib_disc[best_discharge_features], y_calib_disc_enc)\n y_pred_proba_discharge_calibrated = calibrated_discharge_model.predict_proba(X_test_disc[best_discharge_features])\n y_pred_labels_discharge_calibrated = calibrated_discharge_model.predict(X_test_disc[best_discharge_features])\n \n print(\"\\nCalibrated Discharge Model - Test Set Evaluation:\")\n calibrated_disc_metrics = calculate_multiclass_classification_metrics(y_test_disc_enc, y_pred_labels_discharge_calibrated, y_pred_proba_discharge_calibrated, le_discharge.classes_)\n \n # Log the overall multiclass metrics.\n overall_cal_metrics_to_log = {k: v for k, v in calibrated_disc_metrics.items() if k != 'per_class_details'}\n overall_cal_metrics_to_log['BRIER_SCORE'] = calibrated_disc_metrics.get('BRIER_SCORE_MACRO_AVG')\n log_model_metrics_to_snowflake(session, MODEL_RUN_ID, discharge_model_name, TARGET_DISCHARGE, overall_cal_metrics_to_log, \"Multiclass_Cal_Overall\", METRICS_TABLE_NAME, MODEL_SOURCE_TAG, MODEL_YEAR_TAG)\n \n # --- FIX: Log the per-class metrics by mapping keys correctly ---\n discharge_class_model_name = f\"Inpatient_Discharge_Cal_Class_{MODEL_NAME_SUFFIX}{EXCLUSION_SUFFIX}\"\n for class_detail in calibrated_disc_metrics.get('per_class_details', []):\n # Create a new dict with keys the logging function expects.\n per_class_metrics_to_log = {\n 'BRIER_SCORE': class_detail.get('brier_score'),\n 'AVG_Y_PRED': class_detail.get('avg_pred_proba'),\n 'AVG_Y_TRUE': class_detail.get('true_proportion'),\n 'PRED_RATIO': class_detail.get('proba_ratio'),\n }\n log_model_metrics_to_snowflake(\n session, MODEL_RUN_ID, discharge_class_model_name,\n f\"{TARGET_DISCHARGE}_Class_{class_detail['class_name']}\",\n per_class_metrics_to_log, # Use the correctly mapped dictionary\n \"Multiclass_Cal_ClassDetail\", METRICS_TABLE_NAME, MODEL_SOURCE_TAG, MODEL_YEAR_TAG\n )\n\n print(\"\\nCalibrated Classification Report:\\n\", classification_report(y_test_disc_enc, y_pred_labels_discharge_calibrated, target_names=le_discharge.classes_.astype(str), zero_division=0, digits=4))\n\n\n# =============================================================================\n# 5. 
MODEL SAVING\n# =============================================================================\nprint(\"\\n\" + \"=\"*80)\nprint(\"--- Saving Models and Artifacts ---\")\nprint(\"=\"*80)\n\n# Bundle all necessary objects for deployment into a single dictionary.\ninpatient_models_bundle = {\n 'los_model': los_model,\n 'readmission_model': calibrated_readmission_model,\n 'discharge_model': calibrated_discharge_model,\n 'feature_columns_los': best_los_features,\n 'feature_columns_readmission': best_readmission_features,\n 'feature_columns_discharge': best_discharge_features,\n 'le_discharge': le_discharge,\n 'model_run_id': MODEL_RUN_ID,\n 'fast_mode': FAST_MODE,\n 'excluded_feature_prefixes': EXCLUDE_FEATURE_PREFIXES\n}\n\n# Create a descriptive file name for the bundle.\nBUNDLE_SUFFIX = \"fast\" if FAST_MODE else \"fs\"\nEXCLUSION_FILE_TAG = f\"_excl_{'-'.join([p.strip('_').lower() for p in EXCLUDE_FEATURE_PREFIXES])}\" if EXCLUDE_FEATURE_PREFIXES else \"\"\nBUNDLE_FILE_NAME = f'inpatient_models_bundle_{MODEL_SOURCE_TAG}_{MODEL_YEAR_TAG}_{BUNDLE_SUFFIX}{EXCLUSION_FILE_TAG}.pkl'\n\n# Save the bundle locally using pickle.\nwith open(BUNDLE_FILE_NAME, 'wb') as f:\n pickle.dump(inpatient_models_bundle, f)\nprint(f\"Models bundled and saved locally to: {BUNDLE_FILE_NAME}\")\n\n# Upload the local bundle file to the specified Snowflake stage.\nput_result = session.file.put(BUNDLE_FILE_NAME, SNOWFLAKE_STAGE_NAME, overwrite=True)\nif put_result[0].status == 'UPLOADED':\n print(f\"Model bundle successfully uploaded to Snowflake stage: {SNOWFLAKE_STAGE_NAME}\")\nelse:\n print(f\"Error uploading model bundle. Status: {put_result[0].status}, Message: {put_result[0].message}\")\n\nfile_size_mb = os.path.getsize(BUNDLE_FILE_NAME) / (1024 * 1024)\nprint(f\"Saved local bundle file size: {file_size_mb:.2f} MB\")\n\nprint(f\"\\n✅ Script finished ({'FAST MODE' if FAST_MODE else 'FULL MODE'}).\")",
"execution_count": null,
"outputs": []
}
]
}
feature_fill_rate_inpatient.csv
ADDED
The diff for this file is too large to render.
inpatient_feature_importance.csv
ADDED
The diff for this file is too large to render.
inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:afb9b33cac24e91958adfc4f52ed3a1fd617a51c12bbf9d8be2c37eab4eb979c
size 1808823
inpatient_models_eval_metrics.csv
ADDED
@@ -0,0 +1,13 @@
MODEL_RUN_ID,MODEL_NAME,TARGET_NAME,R2,MAE,MSE,PRED_RATIO,MAE_PERCENT,AUC_ROC,AUC_PR,LOG_LOSS,BRIER_SCORE,ACCURACY,AVG_Y_PRED,AVG_Y_TRUE,MODEL_SOURCE,MODEL_TYPE,MODEL_YEAR,EVAL_TS
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_transfer/other facility,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:12:05.239
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_snf,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:12:01.767
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_other,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:58.290
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_ipt rehab,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:55.442
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_hospice,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:52.653
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_home health,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:49.518
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_home,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:46.588
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Class_FeatureSelected,discharge_location_Class_expired,,,,,,,,,,,,,medicare_lds,Multiclass_Cal_ClassDetail,2023,2025-06-18 21:11:43.545
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Discharge_Cal_Overall_FeatureSelected,discharge_location,,,,,,,,1.302877,0.077095,0.521647,,,medicare_lds,Multiclass_Cal_Overall,2023,2025-06-18 21:11:40.194
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Readmission_Calibrated_FeatureSelected,readmission_numerator_Probability,,,,,,0.74825,0.334959,0.380948,0.117589,,0.157001,0.15629,medicare_lds,Binary_Classification_Probability_Calibrated,2023,2025-06-18 20:48:33.051
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_Readmission_Base_Uncalibrated_FeatureSelected,readmission_numerator_Probability,,,,,,0.748529,0.341185,0.594258,0.20216,,0.428988,0.15629,medicare_lds,Binary_Classification_Probability_Uncalibrated,2023,2025-06-18 20:48:29.072
03daf6f5-4a7a-44b9-a670-e0520ec6772f,Inpatient_LOS_FeatureSelected,length_of_stay,0.245662,2.724616,23.875849,0.970435,54.179359,,,,,,4.880204201,5.028881,medicare_lds,Regression,2023,2025-06-18 20:46:30.898
predict inpatient.ipynb
ADDED
@@ -0,0 +1,385 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "3775908f-ca36-4846-8f38-5adca39217f2",
"metadata": {
"language": "python",
"name": "cell1"
},
"outputs": [],
"source": [
"\"\"\"\n",
"Snowflake Inpatient Prediction and Evaluation Script (Self-Contained)\n",
"\n",
"This script is designed to run entirely within a Snowflake environment with NO\n",
"external network access. It performs the following operations:\n",
"\n",
"1. Connects to an active Snowflake session.\n",
"2. Loads a pre-trained model bundle and reference tables from a Snowflake stage.\n",
"3. Loads new inpatient data from a Snowflake table for prediction.\n",
"4. Calculates and compares the feature fill rates of the input data against the training data.\n",
"5. Performs data preprocessing, including one-hot encoding.\n",
"6. Generates predictions for Length of Stay, Readmission, and Discharge Location.\n",
"7. Calculates evaluation metrics by comparing predictions to actual outcomes.\n",
"8. Saves predictions, metrics, and diagnostic tables back to Snowflake.\n",
"\"\"\"\n",
"\n",
"import os\n",
"import gzip\n",
"import pickle\n",
"import uuid\n",
"from datetime import datetime\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from snowflake.snowpark.context import get_active_session\n",
"from snowflake.snowpark.session import Session\n",
"from snowflake.snowpark.exceptions import SnowparkClientException\n",
"\n",
"from sklearn.metrics import (\n",
"    r2_score, mean_absolute_error, mean_squared_error, accuracy_score,\n",
"    roc_auc_score, log_loss, brier_score_loss, average_precision_score\n",
")\n",
"from snowflake.snowpark.types import (\n",
"    StructType, StructField, StringType, TimestampType, FloatType, LongType\n",
")\n",
"\n",
"# =============================================================================\n",
"# 0. CONFIGURATION\n",
"# =============================================================================\n",
"\n",
"# --- Snowflake Environment Settings ---\n",
"DATABASE = \"CMS_SYNTHETIC\"\n",
"SCHEMA = \"BENCHMARKS\"\n",
"STAGE_DATABASE = \"CMS_SYNTHETIC\"\n",
"\n",
"# --- Input & Output Table Names ---\n",
"INPUT_TABLE_NAME = \"BENCHMARKS_INPATIENT_INPUT\"\n",
"OUTPUT_PREDICTIONS_TABLE_NAME = \"INPATIENT_PREDICTIONS\"\n",
"OUTPUT_METRICS_TABLE_NAME = \"INPATIENT_EVALUATION_METRICS\"\n",
"OUTPUT_FILL_RATE_COMPARISON_TABLE_NAME = \"INPATIENT_FILL_RATE_COMPARISON\"\n",
"\n",
"# --- Model & Artifact Loading Configuration ---\n",
"MODEL_STAGE_NAME = f\"@{STAGE_DATABASE}.{SCHEMA}.BENCHMARK_STAGE\"\n",
"MODEL_FILE_NAME_IN_STAGE = \"inpatient_models_bundle_medicare_lds_2023_fs.pkl.gz\"\n",
"FILL_RATE_FILE_NAME_IN_STAGE = \"feature_fill_rate_inpatient.csv.gz\"\n",
"\n",
"# --- Local Directory for Artifacts ---\n",
"# This directory is used within the Snowflake virtual environment to store downloaded files.\n",
"LOCAL_ARTIFACT_DIR = \"/tmp/appRoot\"\n",
"\n",
"# --- Run Configuration ---\n",
"ROW_LIMIT = None  # Set to an integer to limit input rows for testing, or None for all rows.\n",
"RUN_ID = str(uuid.uuid4())\n",
"\n",
"# --- Feature Configuration ---\n",
"# Set to True to calculate and save a comparison of feature fill rates\n",
"# between the input data and the original training data. This requires\n",
"# the fill rate file to be available in the stage. Set to False to skip.\n",
"ENABLE_FILL_RATE_COMPARISON = True\n",
"\n",
"# --- Construct Full Object Names ---\n",
"FULL_INPUT_TABLE = f\"{INPUT_TABLE_NAME}\"\n",
"FULL_OUTPUT_PREDICTIONS_TABLE = f\"{OUTPUT_PREDICTIONS_TABLE_NAME}\"\n",
"FULL_METRICS_TABLE = f\"{OUTPUT_METRICS_TABLE_NAME}\"\n",
"FULL_OUTPUT_FILL_RATE_COMPARISON_TABLE = f\"{OUTPUT_FILL_RATE_COMPARISON_TABLE_NAME}\"\n",
"MODEL_STAGE_PATH = f\"{MODEL_STAGE_NAME}/{MODEL_FILE_NAME_IN_STAGE}\"\n",
"FILL_RATE_STAGE_PATH = f\"{MODEL_STAGE_NAME}/{FILL_RATE_FILE_NAME_IN_STAGE}\"\n",
"\n",
"\n",
"# =============================================================================\n",
"# 1. UTILITY FUNCTIONS\n",
"# =============================================================================\n",
"\n",
"def calculate_regression_metrics(y_true, y_pred):\n",
"    \"\"\"Calculates a dictionary of standard regression metrics.\"\"\"\n",
"    y_true_np, y_pred_np = np.array(y_true), np.array(y_pred)\n",
"    sum_y_true, mean_y_true = np.sum(y_true_np), np.mean(y_true_np)\n",
"    pred_ratio = np.sum(y_pred_np) / sum_y_true if sum_y_true != 0 else np.nan\n",
"    mae_percent = (mean_absolute_error(y_true_np, y_pred_np) / mean_y_true) * 100 if mean_y_true != 0 else np.nan\n",
"    return {\n",
"        'R2': r2_score(y_true_np, y_pred_np), 'MAE': mean_absolute_error(y_true_np, y_pred_np),\n",
"        'MSE': mean_squared_error(y_true_np, y_pred_np), 'PRED_RATIO': pred_ratio, 'MAE_PERCENT': mae_percent,\n",
"        'AVG_Y_PRED': np.mean(y_pred_np), 'AVG_Y_TRUE': mean_y_true\n",
"    }\n",
"\n",
"def calculate_binary_classification_proba_metrics(y_true, y_pred_proba):\n",
"    \"\"\"Calculates a dictionary of standard binary classification metrics from probabilities.\"\"\"\n",
"    y_true_np, y_pred_proba_np = np.array(y_true), np.array(y_pred_proba)\n",
"    is_multiclass = len(np.unique(y_true_np)) > 1\n",
"    auc_roc = roc_auc_score(y_true_np, y_pred_proba_np) if is_multiclass else np.nan\n",
"    auc_pr = average_precision_score(y_true_np, y_pred_proba_np) if is_multiclass else np.nan\n",
"    return {\n",
"        'AUC_ROC': auc_roc, 'AUC_PR': auc_pr, 'LOG_LOSS': log_loss(y_true_np, y_pred_proba_np),\n",
"        'BRIER_SCORE': brier_score_loss(y_true_np, y_pred_proba_np),\n",
"        'AVG_Y_PRED_PROBA': np.mean(y_pred_proba_np), 'AVG_Y_TRUE': np.mean(y_true_np)\n",
"    }\n",
"\n",
"def calculate_multiclass_classification_metrics(y_true_encoded, y_pred_labels, y_pred_proba, le_classes):\n",
"    \"\"\"Calculates a dictionary of standard multi-class classification metrics.\"\"\"\n",
"    num_samples, num_classes = len(y_true_encoded), len(le_classes)\n",
"    metrics = {\n",
"        'ACCURACY': accuracy_score(y_true_encoded, y_pred_labels),\n",
"        'LOG_LOSS': log_loss(y_true_encoded, y_pred_proba, labels=np.arange(num_classes))\n",
"    }\n",
"    per_class_details, all_brier_scores = [], []\n",
"    if num_samples > 0 and num_classes > 0:\n",
"        for i in range(num_classes):\n",
"            class_name, true_class_binary = le_classes[i], (y_true_encoded == i).astype(int)\n",
"            pred_proba_for_class, true_proportion_class = y_pred_proba[:, i], np.mean(true_class_binary)\n",
"            brier_score_class = brier_score_loss(true_class_binary, pred_proba_for_class) if len(np.unique(true_class_binary)) > 1 else np.nan\n",
"            all_brier_scores.append(brier_score_class)\n",
"            per_class_details.append({\n",
"                \"class_name\": class_name, \"avg_pred_proba\": np.mean(pred_proba_for_class),\n",
"                \"true_proportion\": true_proportion_class,\n",
"                \"proba_ratio\": np.mean(pred_proba_for_class) / true_proportion_class if true_proportion_class > 0 else np.nan,\n",
"                \"brier_score\": brier_score_class\n",
"            })\n",
"    metrics['per_class_details'] = per_class_details\n",
"    valid_brier_scores = [s for s in all_brier_scores if not np.isnan(s)]\n",
"    metrics['BRIER_SCORE_MACRO_AVG'] = np.mean(valid_brier_scores) if valid_brier_scores else np.nan\n",
"    return metrics\n",
"\n",
"def log_metrics_to_snowflake(session, run_id, model_id, data_source, year_nbr, model_target, metrics_dict, table_name):\n",
"    \"\"\"Constructs a DataFrame from a metrics dictionary and appends it to a Snowflake table.\"\"\"\n",
"    metrics_schema = StructType([\n",
"        StructField(\"RUN_ID\", StringType(), nullable=False), StructField(\"MODEL_ID\", StringType(), nullable=True),\n",
"        StructField(\"DATA_SOURCE\", StringType(), nullable=True), StructField(\"YEAR_NBR\", LongType(), nullable=True),\n",
"        StructField(\"MODEL_TARGET\", StringType(), nullable=True), StructField(\"EVAL_TIMESTAMP\", TimestampType(), nullable=False),\n",
"        StructField(\"N_SAMPLES\", LongType(), nullable=True), StructField(\"R2\", FloatType(), nullable=True),\n",
"        StructField(\"MAE\", FloatType(), nullable=True), StructField(\"MSE\", FloatType(), nullable=True),\n",
"        StructField(\"PRED_RATIO\", FloatType(), nullable=True), StructField(\"MAE_PERCENT\", FloatType(), nullable=True),\n",
"        StructField(\"AUC_ROC\", FloatType(), nullable=True), StructField(\"AUC_PR\", FloatType(), nullable=True),\n",
"        StructField(\"LOG_LOSS\", FloatType(), nullable=True), StructField(\"BRIER_SCORE\", FloatType(), nullable=True),\n",
"        StructField(\"ACCURACY\", FloatType(), nullable=True), StructField(\"AVG_Y_PRED\", FloatType(), nullable=True),\n",
"        StructField(\"AVG_Y_TRUE\", FloatType(), nullable=True)\n",
"    ])\n",
"    avg_y_pred = metrics_dict.get('AVG_Y_PRED', metrics_dict.get('AVG_Y_PRED_PROBA'))\n",
"    payload = {\n",
"        \"RUN_ID\": run_id, \"MODEL_ID\": model_id, \"DATA_SOURCE\": data_source,\n",
"        \"YEAR_NBR\": int(year_nbr) if pd.notna(year_nbr) else None, \"MODEL_TARGET\": model_target,\n",
"        \"EVAL_TIMESTAMP\": datetime.utcnow(), \"N_SAMPLES\": metrics_dict.get('n_samples'), \"R2\": metrics_dict.get('R2'),\n",
"        \"MAE\": metrics_dict.get('MAE'), \"MSE\": metrics_dict.get('MSE'), \"PRED_RATIO\": metrics_dict.get('PRED_RATIO'),\n",
"        \"MAE_PERCENT\": metrics_dict.get('MAE_PERCENT'), \"AUC_ROC\": metrics_dict.get('AUC_ROC'),\n",
"        \"AUC_PR\": metrics_dict.get('AUC_PR'), \"LOG_LOSS\": metrics_dict.get('LOG_LOSS'),\n",
"        \"BRIER_SCORE\": metrics_dict.get('BRIER_SCORE', metrics_dict.get('BRIER_SCORE_MACRO_AVG')),\n",
"        \"ACCURACY\": metrics_dict.get('ACCURACY'), \"AVG_Y_PRED\": avg_y_pred, \"AVG_Y_TRUE\": metrics_dict.get('AVG_Y_TRUE')\n",
"    }\n",
"    dfm = pd.DataFrame([payload]).replace({np.nan: None, pd.NaT: None})\n",
"    try:\n",
"        column_order = [field.name for field in metrics_schema.fields]\n",
"        dfm_reordered = dfm[column_order]\n",
"        snowpark_df = session.create_dataframe(dfm_reordered.values.tolist(), schema=metrics_schema)\n",
"        snowpark_df.write.mode(\"append\").save_as_table(table_name)\n",
"        print(f\"✅ Logged metrics for {model_target} ({data_source}, {year_nbr}) to {table_name}.\")\n",
"    except Exception as e:\n",
"        print(f\"Error logging metrics for {model_target}: {e}\\nPayload: {dfm.to_dict('records')}\")\n",
"\n",
"# =============================================================================\n",
"# 2. MAIN EXECUTION\n",
"# =============================================================================\n",
"def main(session: Session):\n",
"    print(f\"--- Starting Inpatient Prediction Pipeline ---\\nRun ID: {RUN_ID}\")\n",
"    session.use_database(DATABASE)\n",
"    session.use_schema(SCHEMA)\n",
"    print(f\"Session context set to DATABASE: {DATABASE}, SCHEMA: {SCHEMA}\")\n",
"\n",
"    # --- Stage 1: Load Model & Artifacts ---\n",
"    print(f\"\\n--- Stage 1: Loading artifacts from stage to {LOCAL_ARTIFACT_DIR} ---\")\n",
"    os.makedirs(LOCAL_ARTIFACT_DIR, exist_ok=True)\n",
"    df_training_fill_rates = pd.DataFrame()  # Initialize as empty DataFrame\n",
"\n",
"    try:\n",
"        # The model bundle is a required artifact for the script to run.\n",
"        session.file.get(MODEL_STAGE_PATH, LOCAL_ARTIFACT_DIR)\n",
"        local_model_path = os.path.join(LOCAL_ARTIFACT_DIR, MODEL_FILE_NAME_IN_STAGE)\n",
"        with gzip.open(local_model_path, \"rb\") as f:\n",
"            models_bundle = pickle.load(f)\n",
"        print(\"✅ Model bundle loaded successfully.\")\n",
"\n",
"        # The fill rate file is optional. It is loaded only if the feature is enabled.\n",
"        if ENABLE_FILL_RATE_COMPARISON:\n",
"            try:\n",
"                session.file.get(FILL_RATE_STAGE_PATH, LOCAL_ARTIFACT_DIR)\n",
"                local_fr_path = os.path.join(LOCAL_ARTIFACT_DIR, FILL_RATE_FILE_NAME_IN_STAGE)\n",
"                with gzip.open(local_fr_path, \"rb\") as f:\n",
"                    df_training_fill_rates = pd.read_csv(f)\n",
"                print(\"✅ Training fill rates loaded successfully.\")\n",
"            except Exception as e:\n",
"                print(f\"WARNING: Could not load fill rate file from '{FILL_RATE_STAGE_PATH}'. \"\n",
"                      f\"Fill rate comparison will be skipped. Error: {e}\")\n",
"\n",
"    except Exception as e:\n",
"        print(f\"CRITICAL ERROR: Failed to load the required model bundle from stage. Error: {e}\")\n",
"        return\n",
"\n",
"    los_model, readmission_model, discharge_model = models_bundle['los_model'], models_bundle['readmission_model'], models_bundle['discharge_model']\n",
"    los_features, readmission_features, discharge_features = models_bundle['feature_columns_los'], models_bundle['feature_columns_readmission'], models_bundle['feature_columns_discharge']\n",
"    le_discharge = models_bundle.get('le_discharge')\n",
"    model_version = models_bundle.get('model_run_id', MODEL_FILE_NAME_IN_STAGE)\n",
"    print(f\"  - Using Model Version (from training run): {model_version}\")\n",
"\n",
"    # --- Stage 2: Load New Data for Prediction ---\n",
"    print(f\"\\n--- Stage 2: Loading data from table: {FULL_INPUT_TABLE} ---\")\n",
"    try:\n",
"        query = session.table(FULL_INPUT_TABLE)\n",
"        if ROW_LIMIT:\n",
"            query = query.limit(ROW_LIMIT)\n",
"        df_new_data_pd = query.to_pandas()\n",
"        print(f\"Loaded {len(df_new_data_pd)} rows for prediction.\")\n",
"        if df_new_data_pd.empty:\n",
"            print(\"WARNING: Input data is empty. Exiting script.\")\n",
"            return\n",
"        df_new_data_pd.columns = [col.upper() for col in df_new_data_pd.columns]\n",
"    except Exception as e:\n",
"        print(f\"CRITICAL ERROR: Failed to load data from {FULL_INPUT_TABLE}. Error: {e}\")\n",
"        return\n",
"\n",
"    # --- Stage 3: Preprocess New Data ---\n",
"    print(\"\\n--- Stage 3: Preprocessing data ---\")\n",
"    df_new_data_pd_lower = df_new_data_pd.copy()\n",
"    df_new_data_pd_lower.columns = df_new_data_pd_lower.columns.str.lower()\n",
"    categorical_cols = ['sex', 'state', 'race', 'ms_drg_code', 'ccsr_cat']\n",
"    cols_to_encode = [col for col in categorical_cols if col in df_new_data_pd_lower.columns]\n",
"    if cols_to_encode:\n",
"        df_new_data_encoded = pd.get_dummies(df_new_data_pd_lower, columns=cols_to_encode, dummy_na=False)\n",
"    else:\n",
"        df_new_data_encoded = df_new_data_pd_lower.copy()\n",
"    print(f\"Data preprocessed. Total features after encoding: {len(df_new_data_encoded.columns)}\")\n",
"\n",
"    # --- Stage 4: Calculate & Compare Feature Fill Rates ---\n",
"    print(\"\\n--- Stage 4: Comparing input data fill rate to training data fill rate ---\")\n",
"    if ENABLE_FILL_RATE_COMPARISON:\n",
"        # This block executes only if the feature is enabled in the configuration.\n",
"        if not df_training_fill_rates.empty:\n",
"            # Calculate the non-null rate for all columns.\n",
"            input_fill_rate_series = (df_new_data_encoded.notna().sum() / len(df_new_data_encoded))\n",
"\n",
"            # Identify binary/categorical features to calculate positive rate instead of non-null rate.\n",
"            binary_prefixes = ('hcc_', 'cms_', 'cond_') + tuple(f'{col}_' for col in cols_to_encode)\n",
"            binary_feature_names = [\n",
"                col for col in df_new_data_encoded.columns\n",
"                if col.lower().startswith(binary_prefixes)\n",
"            ]\n",
"            print(f\"Identified {len(binary_feature_names)} binary-like features for positive rate calculation.\")\n",
"\n",
"            # For these binary columns, calculate the positive rate (mean) and update the series.\n",
"            if binary_feature_names:\n",
"                existing_binary_features = [c for c in binary_feature_names if c in df_new_data_encoded.columns]\n",
"                if existing_binary_features:\n",
"                    positive_rates = df_new_data_encoded[existing_binary_features].mean()\n",
"                    input_fill_rate_series.update(positive_rates)\n",
"\n",
"            df_input_fill_rates = input_fill_rate_series.reset_index()\n",
"            df_input_fill_rates.columns = ['FEATURE', 'INPUT_FILL_RATE']\n",
"\n",
"            # Prepare training data fill rates for comparison.\n",
"            df_training_fill_rates.columns = df_training_fill_rates.columns.str.upper()\n",
"            df_training_fill_rates = df_training_fill_rates.rename(columns={'FEATURE_NAME': 'FEATURE'})\n",
"            if 'POSITIVE_RATE_PERCENT' in df_training_fill_rates.columns:\n",
"                df_training_fill_rates['TRAINING_FILL_RATE'] = df_training_fill_rates['POSITIVE_RATE_PERCENT'] / 100.0\n",
"\n",
"            # Merge input and training fill rates on the feature name.\n",
"            df_training_fill_rates['FEATURE'] = df_training_fill_rates['FEATURE'].str.upper()\n",
"            df_input_fill_rates['FEATURE'] = df_input_fill_rates['FEATURE'].str.upper()\n",
"\n",
"            df_comparison = pd.merge(\n",
"                df_training_fill_rates[['FEATURE', 'TRAINING_FILL_RATE']],\n",
"                df_input_fill_rates, on='FEATURE', how='outer'\n",
"            )\n",
"            df_comparison['RUN_ID'] = RUN_ID\n",
"            df_comparison['LAST_RUN'] = pd.Timestamp(datetime.utcnow())\n",
"            df_comparison['FILL_RATE_DIFFERENCE'] = df_comparison['INPUT_FILL_RATE'] - df_comparison['TRAINING_FILL_RATE']\n",
"            df_comparison = df_comparison[['RUN_ID', 'FEATURE', 'TRAINING_FILL_RATE', 'INPUT_FILL_RATE', 'FILL_RATE_DIFFERENCE', 'LAST_RUN']]\n",
"\n",
"            # Save the comparison table to Snowflake.\n",
"            session.write_pandas(df_comparison, FULL_OUTPUT_FILL_RATE_COMPARISON_TABLE, auto_create_table=True, overwrite=True)\n",
"            print(f\"✅ Successfully saved fill rate comparison to {FULL_OUTPUT_FILL_RATE_COMPARISON_TABLE}\")\n",
"        else:\n",
"            print(\"WARNING: Training fill rate data not available. Skipping comparison.\")\n",
"    else:\n",
"        print(\"Fill rate comparison is disabled by configuration. Skipping.\")\n",
"\n",
"    # --- Stage 5: Generate Predictions & Save ---\n",
"    print(\"\\n--- Stage 5: Generating and saving predictions ---\")\n",
"    predictions_df = pd.DataFrame({\n",
"        \"ENCOUNTER_ID\": df_new_data_encoded['encounter_id'],\n",
"        \"LENGTH_OF_STAY_PRED\": los_model.predict(df_new_data_encoded.reindex(columns=los_features, fill_value=0)),\n",
"        \"READMISSION_PRED\": readmission_model.predict_proba(df_new_data_encoded.reindex(columns=readmission_features, fill_value=0))[:, 1],\n",
"        \"LAST_RUN\": datetime.utcnow()\n",
"    })\n",
"    if le_discharge is not None:\n",
"        discharge_probas = discharge_model.predict_proba(df_new_data_encoded.reindex(columns=discharge_features, fill_value=0))\n",
"        for i, class_label in enumerate(le_discharge.classes_):\n",
"            col_name = f\"DISCHARGE_PRED_PROBA_{class_label.upper().replace(' ', '_').replace('/', '_')}\"\n",
"            predictions_df[col_name] = discharge_probas[:, i]\n",
"    session.write_pandas(predictions_df, FULL_OUTPUT_PREDICTIONS_TABLE, auto_create_table=True, overwrite=True)\n",
"    print(f\"✅ Successfully saved {len(predictions_df)} predictions to {FULL_OUTPUT_PREDICTIONS_TABLE}\")\n",
"\n",
"    # --- Stage 6: Calculate and Save Evaluation Metrics ---\n",
"    print(\"\\n--- Stage 6: Calculating and saving evaluation metrics ---\")\n",
"    actual_cols_to_get = ['ENCOUNTER_ID', 'DATA_SOURCE', 'YEAR_NBR', 'LENGTH_OF_STAY', 'READMISSION_NUMERATOR', 'READMISSION_DENOMINATOR', 'DISCHARGE_LOCATION']\n",
"    actual_cols = [col for col in actual_cols_to_get if col in df_new_data_pd.columns]\n",
"    eval_df = pd.merge(predictions_df, df_new_data_pd[actual_cols], on=\"ENCOUNTER_ID\", how=\"left\")\n",
"    eval_df.columns = eval_df.columns.str.lower()\n",
"\n",
"    ACTUAL_LOS_COL_LOWER, ACTUAL_READMISSION_NUM_LOWER, ACTUAL_READMISSION_DENOM_LOWER, ACTUAL_DISCHARGE_COL_LOWER = \"length_of_stay\", \"readmission_numerator\", \"readmission_denominator\", \"discharge_location\"\n",
"    groups_to_iterate = eval_df.groupby(['data_source', 'year_nbr'], dropna=False) if 'data_source' in eval_df.columns and 'year_nbr' in eval_df.columns else [((\"Overall\", -1), eval_df)]\n",
"\n",
"    for group_key, group_df in groups_to_iterate:\n",
"        data_source_val, year_nbr_val = group_key\n",
"        print(f\"\\n--- Calculating metrics for group: {data_source_val}, {year_nbr_val} ---\")\n",
"        if ACTUAL_LOS_COL_LOWER in group_df.columns and not group_df[ACTUAL_LOS_COL_LOWER].isnull().all():\n",
"            metrics = calculate_regression_metrics(group_df[ACTUAL_LOS_COL_LOWER], group_df[\"length_of_stay_pred\"])\n",
"            metrics['n_samples'] = len(group_df)\n",
"            log_metrics_to_snowflake(session, RUN_ID, model_version, data_source_val, year_nbr_val, \"LENGTH_OF_STAY\", metrics, FULL_METRICS_TABLE)\n",
"        if ACTUAL_READMISSION_NUM_LOWER in group_df.columns and ACTUAL_READMISSION_DENOM_LOWER in group_df.columns:\n",
"            readmission_eval_df = group_df[group_df[ACTUAL_READMISSION_DENOM_LOWER] == 1].copy()\n",
"            if not readmission_eval_df.empty and not readmission_eval_df[ACTUAL_READMISSION_NUM_LOWER].isnull().all():\n",
"                metrics = calculate_binary_classification_proba_metrics(readmission_eval_df[ACTUAL_READMISSION_NUM_LOWER], readmission_eval_df[\"readmission_pred\"])\n",
"                metrics['n_samples'] = len(readmission_eval_df)\n",
"                log_metrics_to_snowflake(session, RUN_ID, model_version, data_source_val, year_nbr_val, \"READMISSION\", metrics, FULL_METRICS_TABLE)\n",
"        if ACTUAL_DISCHARGE_COL_LOWER in group_df.columns and le_discharge is not None:\n",
"            group_input = df_new_data_encoded.loc[group_df.index]\n",
"            y_true_labels = group_df[ACTUAL_DISCHARGE_COL_LOWER]\n",
"            known_mask = y_true_labels.isin(le_discharge.classes_)\n",
"            y_true_enc = le_discharge.transform(y_true_labels[known_mask])\n",
"            group_probas = discharge_model.predict_proba(group_input[known_mask].reindex(columns=discharge_features, fill_value=0))\n",
"            group_preds_enc = discharge_model.predict(group_input[known_mask].reindex(columns=discharge_features, fill_value=0))\n",
"            metrics = calculate_multiclass_classification_metrics(y_true_enc, group_preds_enc, group_probas, le_discharge.classes_)\n",
"            metrics['n_samples'] = len(y_true_enc)\n",
"            log_metrics_to_snowflake(session, RUN_ID, model_version, data_source_val, year_nbr_val, \"DISCHARGE_LOCATION_OVERALL\", metrics, FULL_METRICS_TABLE)\n",
"            for detail in metrics.get('per_class_details', []):\n",
"                log_metrics_to_snowflake(session, RUN_ID, model_version, data_source_val, year_nbr_val, f\"DISCHARGE_LOCATION_Class_{detail['class_name']}\", detail, FULL_METRICS_TABLE)\n",
"    print(\"\\n✅ Script finished.\")\n",
"\n",
"if __name__ == \"__main__\":\n",
"    try:\n",
"        snowpark_session = get_active_session()\n",
"        print(\"Successfully retrieved active Snowpark session.\")\n",
"    except SnowparkClientException:\n",
"        print(\"No active session. Creating a new session from local credentials...\")\n",
"        snowpark_session = Session.builder.create()\n",
"    main(snowpark_session)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Streamlit Notebook",
"name": "streamlit"
},
"lastEditStatus": {
"authorEmail": "[email protected]",
"authorId": "374530764978",
"authorName": "BRAD",
"lastEditTime": 1750882358404,
"notebookId": "7e7pzs6ti4k6chxub4nt",
"sessionId": "56467df4-1029-4269-83ed-9a238cb180f6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}