Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,1087 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import joblib
|
4 |
+
import shap
|
5 |
+
import json
|
6 |
+
import plotly
|
7 |
+
import plotly.graph_objs as go
|
8 |
+
import plotly.express as px
|
9 |
+
from flask import Flask, render_template, request, jsonify
|
10 |
+
import math
|
11 |
+
import networkx as nx
|
12 |
+
import traceback # For detailed error logging
|
13 |
+
import os # For environment variables
|
14 |
+
|
15 |
+
app = Flask(__name__)
|
16 |
+
|
17 |
+
# --- Load Models ---
# Both models are loaded once at import time; the app cannot serve
# predictions without them, so any load failure terminates the process.
MODEL_PATH_NA = 'non-adherence_XAI.pkl'
MODEL_PATH_R = 'readmission_XAI.pkl'
try:
    # NOTE: joblib.load unpickles the file contents; only load model files
    # that come from a trusted source.
    model_na = joblib.load(MODEL_PATH_NA)
    model_r = joblib.load(MODEL_PATH_R)
    print("Models loaded successfully.")
    # The rest of the app treats model.predict() output as a 0-1 score, so a
    # model exposing predict_proba (or lacking predict) is only warned about.
    if not (hasattr(model_na, 'predict') and not hasattr(model_na, 'predict_proba')):
        print(f"Warning: Model {MODEL_PATH_NA} might not be a regressor.")
    if not (hasattr(model_r, 'predict') and not hasattr(model_r, 'predict_proba')):
        print(f"Warning: Model {MODEL_PATH_R} might not be a regressor.")

except FileNotFoundError as e:
    print(f"FATAL ERROR: Model file not found: {e}. Make sure the .pkl files are in the correct directory.")
    print("Attempted paths:", os.path.abspath(MODEL_PATH_NA), os.path.abspath(MODEL_PATH_R))
    # Exit with a non-zero status so supervisors/containers see the failure.
    # (The bare exit() helper is injected by `site` and would report status 0.)
    raise SystemExit(1)
except Exception as e:
    print(f"FATAL ERROR: An unexpected error occurred loading models: {e}")
    traceback.print_exc()
    raise SystemExit(1)
|
37 |
+
|
38 |
+
# --- Mappings ---
# Form labels -> integer codes fed to the models.
gender_map = dict(Female=0, Male=1)
why_map = {
    reason: code
    for code, reason in enumerate(
        ('Bone', 'Brain', 'Heart', 'Infection', 'Lung', 'Stomach', 'Surgery')
    )
}
yes_no_map = dict(Yes=1, No=0)

# Inverse lookups: integer code -> form label.
reverse_gender_map = dict(zip(gender_map.values(), gender_map.keys()))
reverse_why_map = dict(zip(why_map.values(), why_map.keys()))
reverse_yes_no_map = dict(zip(yes_no_map.values(), yes_no_map.keys()))

# --- Feature Order (Crucial for model input) ---
# Column order used when building the single-row DataFrame passed to the models.
feature_order = [
    'Age',
    'Gender',
    'Why in Hospital',
    'Hospital Days',
    'Was in ICU (1=Yes)',
    'ICU Days',
    'Number of Medicines',
    'Cost per Medicine (₹)',
    'Days Medicine Lasts',
    'Total Dosage per Day (mg)',
    'Total Pills Given',
    'Medicine Availability (0-1)',
    'Took Medicine Day 1 (1=Yes)',
    'Took Medicine Day 2 (1=Yes)',
    'Took Medicine Day 3 (1=Yes)',
]
|
56 |
+
|
57 |
+
# --- Risk Level Logic ---
|
58 |
+
def get_risk_level(score):
    """Translate a risk score into a Low / Medium / High band.

    The score is converted to float and clamped to the range [0, 1]. A score
    that cannot be converted is logged and treated as 0.

    Returns a dict with the keys 'level', 'color' and 'percentage', where
    'percentage' is the clamped score as a rounded whole-number percent string.
    """
    try:
        numeric = float(score)
    except (ValueError, TypeError):
        print(f"Warning: Invalid score '{score}' received. Defaulting to 0.")
        bounded = 0.0
    else:
        bounded = max(0.0, min(1.0, numeric))

    # Bands: below 0.3 is Low, below 0.7 is Medium, everything else is High.
    if bounded < 0.3:
        level, color = 'Low', 'green'
    elif bounded < 0.7:
        level, color = 'Medium', 'yellow'
    else:
        level, color = 'High', 'red'

    return {'level': level, 'color': color, 'percentage': f"{round(bounded * 100)}%"}
|
73 |
+
|
74 |
+
# --- Feature Metadata for UI ---
# One entry per model feature. 'type' is either 'number' (free numeric input,
# with a suggested 'ideal_range') or 'select' (one of the listed 'options').
feature_info = {
    'Age': dict(
        description='Patient age in years',
        question='How old is the patient?',
        help_text='Age is a significant factor in both medication adherence and hospital readmission risk.',
        ideal_range='18-100',
        type='number',
    ),
    'Gender': dict(
        description='Patient gender',
        question="What is the patient's gender?",
        help_text='Gender can influence medication adherence patterns and readmission risk for certain conditions.',
        options=list(gender_map),
        type='select',
    ),
    'Why in Hospital': dict(
        description='Primary reason for hospitalization',
        question='What is the primary reason for hospitalization?',
        help_text='Different conditions have varying impacts on adherence and readmission patterns.',
        options=list(why_map),
        type='select',
    ),
    'Hospital Days': dict(
        description='Total days spent in hospital during this admission',
        question='How many days did the patient spend in the hospital?',
        help_text='Longer hospital stays often correlate with more complex cases and higher readmission risks.',
        ideal_range='1-30',
        type='number',
    ),
    'Was in ICU (1=Yes)': dict(
        description='Whether patient spent time in ICU',
        question='Did the patient spend time in the ICU?',
        help_text='ICU stays indicate higher severity and may impact post-discharge outcomes.',
        options=list(yes_no_map),
        type='select',
    ),
    'ICU Days': dict(
        description='Total days spent in ICU (if applicable)',
        question='How many days did the patient spend in the ICU? (Enter 0 if not in ICU)',
        help_text='Longer ICU stays typically indicate more severe conditions requiring careful post-discharge planning.',
        ideal_range='0-15',
        type='number',
    ),
    'Number of Medicines': dict(
        description='Total number of different medications prescribed',
        question='How many different medications is the patient prescribed?',
        help_text='Higher medication counts increase complexity and can lead to reduced adherence.',
        ideal_range='1-12',
        type='number',
    ),
    'Cost per Medicine (₹)': dict(
        description='Average cost per medication in rupees',
        question='What is the average cost per medication (in ₹)?',
        help_text='Higher medication costs can impact adherence due to financial constraints.',
        ideal_range='10-1000',
        type='number',
    ),
    'Days Medicine Lasts': dict(
        description='Number of days the prescribed medication will last',
        question='How many days will the prescribed medication last?',
        help_text='Longer durations between refills can affect adherence patterns.',
        ideal_range='7-90',
        type='number',
    ),
    'Total Dosage per Day (mg)': dict(
        description='Total medication dosage per day in milligrams',
        question='What is the total medication dosage per day (in mg)?',
        help_text='Higher daily dosages may indicate more severe conditions and can affect adherence.',
        ideal_range='5-500',
        type='number',
    ),
    'Total Pills Given': dict(
        description='Total number of pills provided at discharge',
        question='How many total pills were given to the patient?',
        help_text='Pill burden is a known factor in medication adherence.',
        ideal_range='10-300',
        type='number',
    ),
    'Medicine Availability (0-1)': dict(
        description='Availability score of prescribed medication (0=low, 1=high)',
        question='How available is the medication (0=low, 1=high)?',
        help_text='Limited availability can significantly impact medication adherence.',
        ideal_range='0-1',
        type='number',
        step='0.01',  # fractional input step for this 0-1 score
    ),
    'Took Medicine Day 1 (1=Yes)': dict(
        description='Whether patient took medication on day 1 post-discharge',
        question='Did the patient take their medication on day 1 after discharge?',
        help_text='Early adherence patterns are strong predictors of overall medication adherence.',
        options=list(yes_no_map),
        type='select',
    ),
    'Took Medicine Day 2 (1=Yes)': dict(
        description='Whether patient took medication on day 2 post-discharge',
        question='Did the patient take their medication on day 2 after discharge?',
        help_text='Consistent adherence in the first days after discharge indicates better overall adherence.',
        options=list(yes_no_map),
        type='select',
    ),
    'Took Medicine Day 3 (1=Yes)': dict(
        description='Whether patient took medication on day 3 post-discharge',
        question='Did the patient take their medication on day 3 after discharge?',
        help_text='Patterns established in the first few days often continue throughout treatment.',
        options=list(yes_no_map),
        type='select',
    ),
}
|
168 |
+
|
169 |
+
# --- Risk Explanations for UI ---
# Static copy shown next to each prediction: a title, a description, one
# sentence per risk level, and lists of consequences and interventions.
risk_explanations = {
    'non_adherence': {
        'title': 'Medication Non-Adherence Risk',
        'description': 'Medication non-adherence refers to the degree to which a patient does not follow their prescription medication regimen as directed by their healthcare provider. This includes missing doses, taking incorrect doses, or stopping treatment early.',
        'levels': {
            'Low': 'Patient is likely to follow medication regimen as prescribed with minimal intervention needed.',
            'Medium': 'Patient may need additional support such as reminders or follow-up calls to ensure adherence.',
            'High': 'Patient is at significant risk of not taking medications as prescribed. Consider intensive interventions and close monitoring.',
        },
        'consequences': [
            'Reduced treatment effectiveness',
            'Disease progression or complications',
            'Increased hospitalization rates',
            'Higher healthcare costs',
            'Poorer health outcomes',
        ],
        'interventions': [
            'Medication reminder systems',
            'Simplified medication regimens',
            'Patient education on importance of adherence',
            'Regular follow-up calls',
            'Addressing barriers (financial, logistical, etc.)',
        ],
    },
    'readmission': {
        'title': 'Hospital Readmission Risk',
        'description': 'Hospital readmission risk refers to the likelihood that a patient will need to return to the hospital within a short period (typically 30 days) after being discharged.',
        'levels': {
            'Low': 'Patient has minimal risk factors for readmission and can likely be managed with standard follow-up care.',
            'Medium': 'Patient has moderate risk of readmission and may benefit from enhanced discharge planning and follow-up.',
            'High': 'Patient is at high risk for readmission and requires comprehensive discharge planning, early follow-up, and possibly home health services.',
        },
        'consequences': [
            'Increased patient suffering',
            'Higher healthcare costs',
            'Potential complications',
            'Disruption to patient recovery',
            'Reduced hospital quality metrics',
        ],
        'interventions': [
            'Comprehensive discharge planning',
            'Medication reconciliation',
            'Early (within 7 days) follow-up appointments',
            'Home health services when appropriate',
            'Patient and caregiver education',
        ],
    },
}
|
209 |
+
|
210 |
+
|
211 |
+
# --- SHAP Explainer Initialization ---
# Cache of explainers keyed by model key, filled lazily on first use.
shap_explainers = {}


def get_shap_explainer(model_key, model):
    """Return the cached SHAP explainer for ``model_key``, building it if needed.

    A TreeExplainer is used when the model exposes ``feature_importances_`` or
    its type name mentions xgboost; otherwise the generic ``shap.Explainer``
    is tried. Returns None (and caches nothing) when construction fails, so a
    later call will retry.
    """
    if model_key in shap_explainers:
        return shap_explainers[model_key]

    print(f"Initializing SHAP explainer for {model_key}...")
    try:
        looks_like_tree = (
            hasattr(model, 'feature_importances_')
            or 'xgboost' in str(type(model)).lower()
        )
        if not looks_like_tree:
            print(f"Warning: Model for {model_key} might not be a tree model. Using generic SHAP Explainer.")
            try:
                shap_explainers[model_key] = shap.Explainer(model)
                print(f"Initialized generic SHAP Explainer for {model_key}.")
            except Exception as gen_e:
                print(f"ERROR initializing generic SHAP explainer for {model_key}: {gen_e}")
                return None
        else:
            shap_explainers[model_key] = shap.TreeExplainer(model)
            print(f"SHAP TreeExplainer for {model_key} initialized.")
    except Exception as e:
        print(f"ERROR initializing SHAP Explainer for {model_key}: {e}")
        traceback.print_exc()
        return None
    return shap_explainers.get(model_key)
|
238 |
+
|
239 |
+
|
240 |
+
# --- Flask Routes ---
|
241 |
+
@app.route('/')
def home():
    """Serve the input form, passing the feature metadata, order and risk copy to the template."""
    template_context = {
        'feature_info': feature_info,
        'feature_order': feature_order,
        'risk_explanations': risk_explanations,
    }
    return render_template('index.html', **template_context)
|
245 |
+
|
246 |
+
# Form field ids posted by the front end (based on getFieldId in the JS)
# mapped to the model feature names.
_JS_KEY_MAP = {
    'age': 'Age',
    'gender': 'Gender',
    'why-in-hospital': 'Why in Hospital',
    'hospital-days': 'Hospital Days',
    'was-in-icu-1yes': 'Was in ICU (1=Yes)',
    'icu-days': 'ICU Days',
    'number-of-medicines': 'Number of Medicines',
    'cost-per-medicine-rupees': 'Cost per Medicine (₹)',
    'days-medicine-lasts': 'Days Medicine Lasts',
    'total-dosage-per-day-mg': 'Total Dosage per Day (mg)',
    'total-pills-given': 'Total Pills Given',
    'medicine-availability-0-1': 'Medicine Availability (0-1)',
    'took-medicine-day-1-1yes': 'Took Medicine Day 1 (1=Yes)',
    'took-medicine-day-2-1yes': 'Took Medicine Day 2 (1=Yes)',
    'took-medicine-day-3-1yes': 'Took Medicine Day 3 (1=Yes)',
}

# Reverse lookup, built once instead of scanning _JS_KEY_MAP for every feature.
_FEATURE_TO_JS_KEY = {feature: js_key for js_key, feature in _JS_KEY_MAP.items()}

# Select features whose options are Yes/No; their input is matched
# case-insensitively via str.capitalize().
_YES_NO_FEATURES = frozenset({
    'Was in ICU (1=Yes)',
    'Took Medicine Day 1 (1=Yes)',
    'Took Medicine Day 2 (1=Yes)',
    'Took Medicine Day 3 (1=Yes)',
})


def _explain_with_shap(explainer, df_user, raw_inputs, pred_score, tag, target_name):
    """Compute per-feature SHAP details for one model.

    Args:
        explainer: SHAP explainer for the model, or a falsy value if unavailable.
        df_user: single-row DataFrame with columns in ``feature_order``.
        raw_inputs: the request payload, used to echo the user's original values.
        pred_score: clamped prediction, used only in the diagnostic log line.
        tag: short label used in log/error text ("NA" or "R").
        target_name: display name used in the user-facing error entry.

    Returns:
        (shap_data, base_value, shap_error). On success ``shap_data`` is a list
        of per-feature dicts sorted by descending absolute SHAP value. On any
        failure it is a one-element list holding an ``error`` entry,
        ``base_value`` is 0.5 and ``shap_error`` is True.
    """
    if not explainer:
        return [{"error": f"SHAP explainer {tag} not available."}], 0.5, True

    try:
        shap_values = explainer.shap_values(df_user)
        # Normalise the shapes SHAP may return down to one vector for the row.
        if isinstance(shap_values, list):
            shap_vec = shap_values[0][0]  # list of per-output arrays: first output, first row
        elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 2:
            shap_vec = shap_values[0]
        elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 1:
            shap_vec = shap_values
        else:
            raise TypeError(f"Unexpected SHAP {tag} format: {type(shap_values)}")

        expected = explainer.expected_value
        if isinstance(expected, (list, np.ndarray)):
            base_value = float(expected[0])
        elif expected is not None:
            base_value = float(expected)
        else:
            base_value = 0.5
            print(f"Warning: SHAP {tag} expected_value is None. Using 0.5.")

        if shap_vec is None or len(shap_vec) != len(feature_order):
            raise ValueError(f"SHAP {tag} vector length mismatch or None.")

        print(f"SHAP {tag}: Base={base_value:.4f}, Sum={np.sum(shap_vec):.4f}, Pred={pred_score:.4f}, Total={base_value + np.sum(shap_vec):.4f}")
        shap_data = []
        for i, feature in enumerate(feature_order):
            f_info = feature_info.get(feature, {})
            shap_data.append({
                'feature': feature,
                'shap_value': float(shap_vec[i]),
                'feature_value': str(raw_inputs.get(_FEATURE_TO_JS_KEY.get(feature), "N/A")),
                'numeric_value': float(df_user.iloc[0, i]),
                'description': f_info.get('description', ''),
                'help_text': f_info.get('help_text', ''),
            })
        shap_data.sort(key=lambda item: abs(item['shap_value']), reverse=True)
        return shap_data, base_value, False
    except Exception as e:
        # Explanations are best-effort: the prediction is still returned.
        print(f"Error calculating SHAP {tag}: {e}")
        traceback.print_exc()
        return [{"error": f"Could not calculate SHAP values for {target_name}."}], 0.5, True


@app.route('/predict', methods=['POST'])
def predict():
    """Validate the posted form data, run both models and return a JSON report.

    Expects a JSON object keyed by the form field ids in ``_JS_KEY_MAP``.

    Returns:
        200 with predictions, SHAP explanations, counterfactuals,
        recommendations and chart data on success.
        400 with ``{'success': False, 'error': ...}`` when the body is missing,
        is not a JSON object, or contains missing/invalid fields.
        500 for internal failures.
    """
    try:
        # silent=True makes a malformed or non-JSON body come back as None so
        # it is answered with a 400 below; without it the parsing error would
        # be caught by the generic handler and reported as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'success': False, 'error': 'No data received'}), 400
        if not isinstance(data, dict):
            return jsonify({'success': False, 'error': 'Request body must be a JSON object'}), 400
        print("Received data:", data)

        # --- Input Processing and Validation ---
        user_input = {}
        missing_features = []
        invalid_features = {}

        # 'Was in ICU' is read first because it decides how 'ICU Days' is validated.
        was_icu_raw = data.get('was-in-icu-1yes')
        is_icu_no = isinstance(was_icu_raw, str) and was_icu_raw.capitalize() == 'No'

        for js_key, feature in _JS_KEY_MAP.items():
            if feature not in feature_order:
                print(f"Warning: Feature '{feature}' from js_key_map not in expected feature_order list.")
                continue

            is_icu_days = (feature == 'ICU Days')
            value = data.get(js_key)

            # Handle missing values
            if value is None or str(value).strip() == '':
                # ICU days is allowed to be missing/empty only if Was in ICU is 'No'
                if is_icu_days and is_icu_no:
                    user_input[feature] = 0.0
                    print("Setting ICU Days to 0 as Was in ICU is No.")
                else:
                    missing_features.append(feature)
                continue

            f_info = feature_info.get(feature)
            if not f_info:
                print(f"Warning: No feature_info found for '{feature}'. Skipping.")
                continue

            try:
                if f_info['type'] == 'number':
                    number = float(value)
                    if not math.isfinite(number):
                        # float() accepts "nan"/"inf"; never pass those to the models.
                        invalid_features[feature] = f"Invalid format for value '{value}': must be a finite number"
                    elif is_icu_days and is_icu_no:
                        # A 'No' answer for ICU overrides whatever was typed for ICU days.
                        if number != 0:
                            print(f"Warning: Forcing ICU Days to 0 because Was in ICU is No, but user entered {value}.")
                        user_input[feature] = 0.0
                    elif is_icu_days and number < 0:
                        invalid_features[feature] = "ICU Days cannot be negative if patient was in ICU."
                    else:
                        user_input[feature] = number

                elif f_info['type'] == 'select':
                    mapped_value = None
                    # Non-string values can never match an option label.
                    if isinstance(value, str):
                        if feature in _YES_NO_FEATURES:
                            mapped_value = yes_no_map.get(value.capitalize())
                        elif feature == 'Gender':
                            mapped_value = gender_map.get(value)
                        elif feature == 'Why in Hospital':
                            mapped_value = why_map.get(value)

                    if mapped_value is None:
                        invalid_features[feature] = f"Invalid option: '{value}'"
                    else:
                        user_input[feature] = mapped_value
                else:
                    invalid_features[feature] = f"Unknown feature type: '{f_info['type']}'"

            except (ValueError, TypeError) as e:
                invalid_features[feature] = f"Invalid format for value '{value}': {e}"

        # Handle Validation Errors
        if missing_features:
            return jsonify({'success': False, 'error': f"Missing input for: {', '.join(missing_features)}"}), 400
        if invalid_features:
            error_msg = "; ".join([f"{k}: {v}" for k, v in invalid_features.items()])
            return jsonify({'success': False, 'error': f"Invalid input: {error_msg}"}), 400

        # Ensure all features are present before creating DataFrame
        if len(user_input) != len(feature_order):
            provided = set(user_input.keys())
            expected = set(feature_order)
            missing = list(expected - provided)
            extra = list(provided - expected)
            err = f"Feature mismatch after processing. Missing: {missing}. Extra: {extra}."
            print(f"ERROR: {err}")
            return jsonify({'success': False, 'error': f"Internal error: Feature mismatch. Please check processing logic. Missing: {missing}"}), 500

        # Create DataFrame in the correct order
        df_user = pd.DataFrame([user_input], columns=feature_order)
        print("Processed DataFrame for prediction:\n", df_user.to_string())

        # --- Model Predictions ---
        # Raw outputs are clamped to [0, 1] before being reported as risk scores.
        pred_na_raw = model_na.predict(df_user)[0]
        pred_r_raw = model_r.predict(df_user)[0]
        pred_na_score = max(0.0, min(1.0, float(pred_na_raw)))
        pred_r_score = max(0.0, min(1.0, float(pred_r_raw)))
        print(f"Predictions - NA Score: {pred_na_score:.4f}, R Score: {pred_r_score:.4f}")
        risk_level_na = get_risk_level(pred_na_score)
        risk_level_r = get_risk_level(pred_r_score)

        # --- SHAP Explanations ---
        shap_explainer_na = get_shap_explainer('non_adherence', model_na)
        shap_explainer_r = get_shap_explainer('readmission', model_r)
        shap_data_na, base_value_na, shap_error_na = _explain_with_shap(
            shap_explainer_na, df_user, data, pred_na_score, "NA", "Non-Adherence"
        )
        shap_data_r, base_value_r, shap_error_r = _explain_with_shap(
            shap_explainer_r, df_user, data, pred_r_score, "R", "Readmission"
        )

        # --- Counterfactuals & Recommendations ---
        cf_data_na = generate_comprehensive_counterfactuals(df_user, model_na, "Non-Adherence", shap_data_na, shap_error_na)
        cf_data_r = generate_comprehensive_counterfactuals(df_user, model_r, "Readmission", shap_data_r, shap_error_r)
        recommendations = generate_recommendations(shap_data_na, shap_data_r, pred_na_score, pred_r_score, shap_error_na, shap_error_r)

        # --- Visualizations ---
        gauges = generate_gauge_charts(pred_na_score, pred_r_score)
        additional_visualizations = generate_additional_visualizations(
            df_user, shap_data_na, shap_data_r, base_value_na, base_value_r, shap_error_na, shap_error_r
        )

        # --- Response ---
        response = {
            'success': True,
            'predictions': {
                'non_adherence': round(pred_na_score, 3),
                'readmission': round(pred_r_score, 3),
                'risk_level_na': risk_level_na,
                'risk_level_r': risk_level_r,
                'base_value_na': round(base_value_na, 3) if base_value_na is not None else None,
                'base_value_r': round(base_value_r, 3) if base_value_r is not None else None
            },
            'explanations': {
                'shap_values_na': shap_data_na,
                'shap_values_r': shap_data_r,
                'counterfactuals_na': cf_data_na,
                'counterfactuals_r': cf_data_r,
                'shap_error_na': shap_error_na,
                'shap_error_r': shap_error_r
            },
            'recommendations': recommendations,
            'visualizations': {
                'gauges': gauges,
                'additional': additional_visualizations
            }
        }
        return jsonify(response)

    except ValueError as ve:
        print(f"Value Error during prediction processing: {ve}")
        traceback.print_exc()
        return jsonify({'success': False, 'error': f"Invalid input: {str(ve)}"}), 400
    except KeyError as ke:
        print(f"Key Error during prediction processing: {ke}")
        traceback.print_exc()
        return jsonify({'success': False, 'error': f"Missing expected data field: {str(ke)}"}), 400
    except Exception as e:
        # Last-resort boundary: log the details, return a generic message.
        print(f"An unexpected error occurred during prediction: {e}")
        traceback.print_exc()
        return jsonify({'success': False, 'error': "An internal server error occurred. Please try again later."}), 500
|
500 |
+
|
501 |
+
|
502 |
+
# --- Placeholder JSON for Plots on Error ---
def create_placeholder_plot(title_suffix, message="Required data not available for this chart."):
    """Return a JSON string for an empty figure that only shows ``message``.

    The figure has no traces, hidden axes, a transparent background and a
    title of ``"<title_suffix> (Not Generated)"``; ``message`` is drawn as a
    centred annotation.
    """
    transparent = 'rgba(0,0,0,0)'
    notice = {
        'text': message,
        'xref': 'paper', 'yref': 'paper',
        'x': 0.5, 'y': 0.5, 'showarrow': False,
        'font': {'size': 12, 'color': '#888'},
    }
    layout = {
        'title': f'{title_suffix} (Not Generated)',
        'xaxis': {'visible': False},
        'yaxis': {'visible': False},
        'annotations': [notice],
        'plot_bgcolor': transparent,
        'paper_bgcolor': transparent,
    }
    figure = {'data': [], 'layout': layout}
    return json.dumps(figure, cls=plotly.utils.PlotlyJSONEncoder)
|
522 |
+
|
523 |
+
# --- Counterfactual Generation ---
|
524 |
+
# (Keep generate_comprehensive_counterfactuals as it was)
|
525 |
+
def _coerce_finite_float(value):
    """Return ``value`` as a finite float, or None if it is missing, non-numeric, NaN or infinite."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if math.isfinite(number) else None


def _propose_feature_change(feature, numeric_value, shap_value):
    """Return ``(new_value, suggestion_text)`` for one feature, or None when no change is proposed."""
    if feature.startswith("Took Medicine Day"):
        # Only suggest adherence when the dose was missed and that miss raises the risk.
        if numeric_value == 0 and shap_value > 0.01:
            return 1, "Ensure Adherence ('Yes')"
        return None

    if feature == "Medicine Availability (0-1)":
        if shap_value < -0.01 and numeric_value < 0.95:
            return 1.0, "Improve towards High (1.0)"
        return None

    # Other numeric features: require a slightly stronger SHAP signal.
    if abs(shap_value) <= 0.02:
        return None

    # A positive SHAP value pushes the score up, so the feature should go down.
    direction = "decrease" if shap_value > 0 else "increase"
    change_perc = 0.20
    target_factor = (1 - change_perc) if direction == "decrease" else (1 + change_perc)
    tentative = numeric_value * target_factor

    # Apply per-feature floors and rounding, and format the matching target text.
    if feature in ("Number of Medicines", "Total Pills Given"):
        new_value = max(1, math.floor(tentative))
        target_text = f"~{int(new_value)}"
    elif feature == "Days Medicine Lasts":
        new_value = max(7, math.floor(tentative))
        target_text = f"~{int(new_value)} days"
    elif feature == "Cost per Medicine (₹)":
        new_value = max(10.0, round(tentative, 0))
        target_text = f"~₹{new_value:.0f}"
    elif feature == "Total Dosage per Day (mg)":
        new_value = max(5.0, round(tentative, 0))
        target_text = f"~{new_value:.0f} mg"
    else:
        new_value = round(tentative, 2)
        target_text = f"~{new_value:.2f}"

    delta = new_value - numeric_value
    # The floors above can push the value the wrong way (e.g. "decrease" 5 days -> 7);
    # such a proposal contradicts its own wording, so it is dropped.
    moves_as_intended = delta < 0 if direction == "decrease" else delta > 0
    is_significant = abs(delta) > 0.01 * abs(numeric_value) + 0.1
    if not (moves_as_intended and is_significant):
        return None
    return new_value, f"{direction.capitalize()} towards {target_text}"


def generate_comprehensive_counterfactuals(df_original, model, target_type, shap_data, shap_error):
    """Generate simplified counterfactuals based on SHAP values and feature modifiability.

    Features are visited in order of decreasing absolute SHAP value. For each
    modifiable feature a changed value is proposed, written into a copy of the
    input frame, and re-scored with ``model.predict``. A proposal is kept only
    if the re-scored prediction drops by more than 0.01.

    Args:
        df_original: Frame passed to ``model.predict``; it is copied, never mutated.
        model: Object whose ``predict(frame)[0]`` is convertible to float.
        target_type: Risk name used in messages and stored as ``risk_type``.
        shap_data: List of dicts read for the keys ``feature``, ``shap_value``,
            ``feature_value`` and ``numeric_value``. Entries that are not dicts,
            contain an ``error`` key, or carry non-numeric/NaN values are skipped.
        shap_error: Truthy when SHAP computation failed upstream.

    Returns:
        A list with up to five counterfactual dicts (``feature``, ``current_value``,
        ``suggested_change``, ``potential_outcome``, ``impact_magnitude``,
        ``risk_type``); or a single ``{"notes": ...}`` dict when nothing qualified;
        or a single ``{"error": ...}`` dict when SHAP data is unusable or an
        unexpected exception occurs.
    """
    print(f"Generating counterfactuals for {target_type}...")

    shap_unusable = (
        shap_error
        or not isinstance(shap_data, list)
        or not shap_data
        or (isinstance(shap_data[0], dict) and shap_data[0].get("error"))
    )
    if shap_unusable:
        print(f"Skipping counterfactuals for {target_type} due to SHAP errors or missing data.")
        return [{"error": f"Could not generate counterfactuals for {target_type} because SHAP data is missing or invalid."}]

    non_modifiable_features = {"Age", "Gender", "Why in Hospital", "Hospital Days", "Was in ICU (1=Yes)", "ICU Days"}
    max_cfs = 5
    counterfactuals = []

    df = df_original.copy()
    try:
        # Scores are clamped to [0, 1] before comparison.
        current_pred_score = max(0.0, min(1.0, float(model.predict(df)[0])))

        # Keep only well-formed entries, paired with their numeric SHAP value.
        candidates = []
        for item in shap_data:
            if not isinstance(item, dict) or "error" in item:
                continue
            shap_value = _coerce_finite_float(item.get('shap_value', 0))
            if shap_value is not None:
                candidates.append((shap_value, item))
        candidates.sort(key=lambda pair: abs(pair[0]), reverse=True)

        for shap_value, feature_data in candidates:
            if len(counterfactuals) >= max_cfs:
                break

            feature = feature_data.get('feature')
            numeric_value = _coerce_finite_float(feature_data.get('numeric_value'))
            if (
                not isinstance(feature, str)
                or numeric_value is None
                or abs(shap_value) < 0.01
                or feature in non_modifiable_features
            ):
                continue

            proposal = _propose_feature_change(feature, numeric_value, shap_value)
            if proposal is None:
                continue
            new_value, suggested_val_str = proposal

            # --- Simulate Outcome ---
            df_cf = df.copy()
            df_cf[feature] = new_value
            try:
                cf_pred_score = max(0.0, min(1.0, float(model.predict(df_cf)[0])))
            except Exception as sim_e:
                # A failed simulation only drops this candidate, not the whole run.
                print(f"Error simulating counterfactual for {feature}: {sim_e}")
                continue

            change_in_score = cf_pred_score - current_pred_score
            if change_in_score >= -0.01:
                continue  # Not a meaningful reduction in risk.

            if abs(change_in_score) > 0.10:
                impact_magnitude = "Significant"
            elif abs(change_in_score) > 0.05:
                impact_magnitude = "Moderate"
            else:
                impact_magnitude = "Minor"

            counterfactuals.append({
                "feature": feature,
                "current_value": feature_data.get('feature_value', 'N/A'),
                "suggested_change": suggested_val_str,
                "potential_outcome": f"Est. new risk: {cf_pred_score:.1%} ({change_in_score:+.1%})",
                "impact_magnitude": impact_magnitude,
                "risk_type": target_type,
            })

        # Counted before the placeholder note is added, so the log reports real entries only.
        generated = len(counterfactuals)
        if not counterfactuals:
            counterfactuals.append({"notes": f"No simple, impactful counterfactual changes identified among top modifiable factors for {target_type} risk."})
        print(f"Finished generating {generated} counterfactual entries for {target_type}.")
        return counterfactuals
    except Exception as e:
        print(f"Error generating counterfactuals for {target_type}: {str(e)}")
        traceback.print_exc()
        return [{"error": f"Could not generate counterfactuals for {target_type}: {str(e)}"}]
|
632 |
+
|
633 |
+
# --- Recommendation Generation ---
|
634 |
+
# Builds prioritized recommendations from the predicted risk levels and the strongest SHAP drivers.
|
635 |
+
def generate_recommendations(shap_data_na, shap_data_r, pred_na_prob, pred_r_prob, shap_error_na, shap_error_r):
    """Generate actionable recommendations based on risk levels and SHAP factors.

    General recommendations come from the two risk levels returned by
    ``get_risk_level`` (read here for its ``level`` and ``percentage`` keys).
    Feature-specific recommendations are then added for up to seven features,
    visited in order of decreasing combined absolute SHAP value, whose SHAP
    value exceeds 0.02 for either outcome.

    Args:
        shap_data_na: SHAP entries for non-adherence; a list of dicts read for
            ``feature``, ``shap_value``, ``feature_value`` and ``numeric_value``.
            May be None or start with an ``error`` entry when SHAP failed.
        shap_data_r: Same structure for readmission.
        pred_na_prob: Non-adherence prediction passed to ``get_risk_level``.
        pred_r_prob: Readmission prediction passed to ``get_risk_level``.
        shap_error_na: Truthy when non-adherence SHAP computation failed.
        shap_error_r: Truthy when readmission SHAP computation failed.

    Returns:
        A list of dicts with keys ``category``, ``priority``, ``recommendation``
        and ``action``, sorted Critical -> High -> Medium. When neither SHAP
        list is usable, the general recommendations plus one explanatory entry
        are returned unsorted.
    """
    print("Generating recommendations...")
    recommendations = []

    def _shap_usable(shap_data, shap_error):
        """True when the SHAP list is a non-empty list of entries without a leading error."""
        return bool(
            not shap_error
            and isinstance(shap_data, list)
            and shap_data
            and isinstance(shap_data[0], dict)
            and "error" not in shap_data[0]
        )

    def _to_number(value, cast, default):
        """Convert ``value`` with ``cast``; fall back to ``default`` if missing or unconvertible."""
        if value is None:
            return default
        try:
            return cast(value)
        except (TypeError, ValueError, OverflowError):
            return default

    na_shap_valid = _shap_usable(shap_data_na, shap_error_na)
    r_shap_valid = _shap_usable(shap_data_r, shap_error_r)

    na_risk_level_info = get_risk_level(pred_na_prob)
    r_risk_level_info = get_risk_level(pred_r_prob)
    na_level = na_risk_level_info['level']
    r_level = r_risk_level_info['level']

    # --- General Recommendations ---
    if na_level == 'High':
        recommendations.append({
            "category": "Overall Non-Adherence", "priority": "Critical",
            "recommendation": f"High ({na_risk_level_info['percentage']}) non-adherence risk detected.",
            "action": "Implement intensive adherence support: daily reminders, regimen simplification consult, frequent follow-up (e.g., within 3 days), assess/address specific barriers (cost, access, understanding)."
        })
    elif na_level == 'Medium':
        recommendations.append({
            "category": "Overall Non-Adherence", "priority": "High",
            "recommendation": f"Medium ({na_risk_level_info['percentage']}) non-adherence risk detected.",
            "action": "Provide adherence aids (e.g., pillbox, app reminders), schedule follow-up call within 1 week, reinforce importance of medication."
        })
    if r_level == 'High':
        recommendations.append({
            "category": "Overall Readmission", "priority": "Critical",
            "recommendation": f"High ({r_risk_level_info['percentage']}) readmission risk detected.",
            "action": "Comprehensive discharge plan crucial: schedule follow-up within 7 days, ensure medication reconciliation, consider home health/transitional care referral, detailed patient/caregiver education using teach-back."
        })
    elif r_level == 'Medium':
        recommendations.append({
            "category": "Overall Readmission", "priority": "High",
            "recommendation": f"Medium ({r_risk_level_info['percentage']}) readmission risk detected.",
            "action": "Enhanced discharge process: ensure clear instructions (written/verbal), schedule follow-up within 10-14 days, confirm patient understanding of red flags and who to contact."
        })

    # --- Specific Recommendations ---
    if not na_shap_valid and not r_shap_valid:
        recommendations.append({
            "category": "Explanations", "priority": "Medium",
            "recommendation": "Detailed factor analysis unavailable due to SHAP errors.",
            "action": "Focus on general risk levels and standard protocols. Investigate SHAP calculation issues if persistent."})
        return recommendations  # Early exit

    # Merge both SHAP lists per feature. Each list fills its own slot, so an
    # unusable (possibly None) list is never touched and entries are never
    # attributed to the wrong outcome.
    combined_shap = {}
    shap_sources = (
        ('data_na', shap_data_na, na_shap_valid),
        ('data_r', shap_data_r, r_shap_valid),
    )
    for slot, shap_data, is_valid in shap_sources:
        if not is_valid:
            continue
        for item in shap_data:
            if not isinstance(item, dict) or "error" in item:
                continue
            feature = item.get('feature')
            if not feature or not isinstance(feature, str):
                continue
            entry = combined_shap.setdefault(feature, {'abs_shap_sum': 0, 'data_na': None, 'data_r': None})
            entry['abs_shap_sum'] += abs(item.get('shap_value', 0))
            entry[slot] = item
    sorted_features = sorted(combined_shap, key=lambda f: combined_shap[f]['abs_shap_sum'], reverse=True)

    rec_count = 0
    max_recs = 7
    shap_threshold = 0.02  # Minimum SHAP value to consider as a driver

    for feature in sorted_features:
        if rec_count >= max_recs:
            break

        shap_entry = combined_shap[feature]
        na_item = shap_entry['data_na']
        r_item = shap_entry['data_r']
        na_shap = na_item.get('shap_value', 0) if na_item else 0
        r_shap = r_item.get('shap_value', 0) if r_item else 0
        item_for_val = na_item or r_item
        current_val = item_for_val.get('feature_value', 'N/A') if item_for_val else 'N/A'
        numeric_val = item_for_val.get('numeric_value') if item_for_val else None

        # Only features that push a risk upwards count as drivers.
        is_na_driver = na_shap > shap_threshold
        is_r_driver = r_shap > shap_threshold
        if not (is_na_driver or is_r_driver):
            continue

        priority = "Medium"
        if max(abs(na_shap), abs(r_shap)) > 0.1:
            priority = "High"
        if (is_na_driver and na_level == 'High') or (is_r_driver and r_level == 'High'):
            priority = "High"
        if (na_level == 'High' and na_shap > 0.15) or (r_level == 'High' and r_shap > 0.15):
            priority = "Critical"

        rec_made = False
        action_text = ""
        rec_category = ""
        rec_recommendation = ""

        # --- Generate Recommendation Content ---
        if feature.startswith("Took Medicine Day") and current_val == "No":
            if is_na_driver:  # Primarily an adherence issue
                priority = "Critical"
                rec_category = "Early Adherence Failure"
                rec_recommendation = f"Missed medication on {feature.split('(')[0].strip()} is a strong indicator of future non-adherence."
                action_text = "Immediate intervention required: follow-up call TODAY, assess reasons, counsel, establish reminders/support."
                rec_made = True
        elif feature == "Number of Medicines":
            # At least one driver flag is set here, so the label is never empty.
            risk_labels = " & ".join(
                label for label, is_driver in (("NA", is_na_driver), ("R", is_r_driver)) if is_driver
            )
            rec_category = "Medication Complexity"
            rec_recommendation = f"High number of medicines ({current_val}) associated with increased risk ({risk_labels})."
            action_text = "Review list for simplification/consolidation. Consider pharmacist consult for polypharmacy review."
            rec_made = True
        elif feature == "Cost per Medicine (₹)":
            cost_val = _to_number(numeric_val, float, 0)
            if cost_val > 100 and is_na_driver:  # Primarily adherence cost barrier
                rec_category = "Medication Cost Barrier"
                rec_recommendation = f"High average medication cost (₹{current_val}) may be a barrier to adherence."
                action_text = "Discuss cost concerns. Explore generics, assistance programs, or lower-cost alternatives."
                rec_made = True
        elif feature == "Medicine Availability (0-1)":
            avail_val = _to_number(numeric_val, float, 1.0)
            if avail_val < 0.5 and is_na_driver:  # Primarily adherence access issue
                rec_category = "Medication Access Issue"
                rec_recommendation = f"Reported low medication availability ({current_val}) likely hinders adherence."
                action_text = "Verify pharmacy stock pre-discharge. Help patient identify reliable source or discuss alternatives."
                rec_made = True
        elif feature == "ICU Days":
            icu_days_val = _to_number(numeric_val, int, 0)
            if icu_days_val > 0 and is_r_driver:  # Primarily readmission risk
                # Stays over two days raise the priority to at least "High";
                # an already "Critical" priority is kept rather than lowered.
                if icu_days_val > 2 and priority == "Medium":
                    priority = "High"
                rec_category = "ICU History Impact"
                rec_recommendation = f"Prior ICU stay ({current_val} days) significantly increases readmission risk."
                action_text = "Intensive post-discharge support: early follow-up (≤7 days), consider transitional care/home health, meticulous med review & education."
                rec_made = True
        elif feature == "Hospital Days":
            days_val = _to_number(numeric_val, int, 0)
            if days_val > 7 and is_r_driver:  # Primarily readmission risk
                rec_category = "Length of Stay Impact"
                rec_recommendation = f"Longer hospital stay ({current_val} days) associated with increased readmission risk."
                action_text = "Reinforces need for thorough discharge planning, clear instructions, med reconciliation, and prompt follow-up (≤10 days)."
                rec_made = True
        elif feature == "Age":
            age_val = _to_number(numeric_val, int, 0)
            if age_val > 75 and is_r_driver:  # Readmission context factor
                priority = "Medium"
                rec_category = "Age Factor (Context)"
                rec_recommendation = f"Patient's age ({current_val}) contributes moderately to readmission risk."
                action_text = "Consider age-related needs in discharge plan (support, mobility, cognition, simple instructions)."
                rec_made = True
        elif feature == "Why in Hospital":
            if is_r_driver:  # Readmission context factor
                rec_category = "Diagnosis Factor (Context)"
                rec_recommendation = f"Primary diagnosis ({current_val}) contributes to risk."
                # str() guards against a non-string display value.
                action_text = f"Ensure condition-specific discharge education (red flags, follow-up) and care plan for managing '{str(current_val).lower()}' are emphasized."
                rec_made = True

        # Append the recommendation if one was generated
        if rec_made:
            recommendations.append({
                "category": rec_category,
                "priority": priority,
                "recommendation": rec_recommendation,
                "action": action_text
            })
            rec_count += 1

    # Add default recommendation if few specific ones generated but risk elevated
    if (na_level in ['High', 'Medium'] or r_level in ['High', 'Medium']) and len(recommendations) < 3:
        recommendations.append({
            "category": "General Follow-up", "priority": "Medium",
            "recommendation": "Overall risk is elevated. Review standard discharge protocols.",
            "action": "Ensure robust basics: use teach-back, confirm follow-up appointments, provide clear contact info."})

    # Sort final recommendations by priority (stable, so insertion order breaks ties)
    priority_map = {"Critical": 0, "High": 1, "Medium": 2, "Standard": 3, "Info": 4}
    recommendations.sort(key=lambda x: priority_map.get(x.get("priority", "Medium"), 99))

    print(f"Finished generating {len(recommendations)} recommendations.")
    return recommendations
|
816 |
+
|
817 |
+
|
818 |
+
# --- Visualization Generation ---
|
819 |
+
# Gauge charts for the two predicted risks (non-adherence and readmission).
|
820 |
+
def generate_gauge_charts(pred_na_prob, pred_r_prob):
    """Generate Plotly gauge chart JSON objects.

    Each probability is clamped to [0, 1], shown as a whole percentage, and
    coloured according to ``get_risk_level`` (read here for its ``color`` and
    ``level`` keys).

    Args:
        pred_na_prob: Non-adherence probability.
        pred_r_prob: Readmission probability.

    Returns:
        A dict with keys ``non_adherence`` and ``readmission``, each holding the
        JSON string of a Plotly indicator figure.
    """
    print("Generating gauge charts...")
    bar_colors = {'green': '#28a745', 'yellow': '#ffc107', 'red': '#dc3545'}
    color_low_bg = '#d4edda'
    color_medium_bg = '#fff3cd'
    color_high_bg = '#f8d7da'

    def _gauge_json(title, probability):
        """Serialize one gauge figure; every nested dict is built fresh so gauges share no state."""
        probability = max(0.0, min(1.0, probability))
        risk_info = get_risk_level(probability)
        value_perc = round(probability * 100)
        bar_color = bar_colors.get(risk_info['color'], '#888')

        gauge = {
            'axis': {'range': [0, 100], 'tickwidth': 1, 'tickcolor': "#aaa", 'tickfont': {'size': 10}},
            'bar': {'color': bar_color, 'thickness': 0.3},
            'bgcolor': 'rgba(0,0,0,0)',
            'borderwidth': 0,
            # Background bands: 0-30 low, 30-70 medium, 70-100 high.
            'steps': [
                {'range': [0, 30], 'color': color_low_bg},
                {'range': [30, 70], 'color': color_medium_bg},
                {'range': [70, 100], 'color': color_high_bg},
            ],
            'threshold': {'line': {'color': '#666', 'width': 4}, 'thickness': 0.75, 'value': value_perc},
        }
        layout = {
            'height': 220,
            'margin': {'t': 50, 'b': 10, 'l': 20, 'r': 20},
            'font': {'color': "#333", 'family': "Arial, sans-serif", 'size': 12},
            'paper_bgcolor': 'rgba(0,0,0,0)',
            'plot_bgcolor': 'rgba(0,0,0,0)',
            'autosize': True,
        }
        indicator = {
            'type': 'indicator',
            'mode': 'gauge+number',
            'value': value_perc,
            'title': {
                'text': f"<b>{title}</b><br><span style='font-size:0.9em;'>Level: {risk_info['level']}</span>",
                'font': {'size': 14},
            },
            'gauge': gauge,
            'number': {'suffix': "%", 'font': {'size': 24, 'color': bar_color}},
        }
        return json.dumps({'data': [indicator], 'layout': layout}, cls=plotly.utils.PlotlyJSONEncoder)

    gauges = {
        'non_adherence': _gauge_json("Non-Adherence Risk", pred_na_prob),
        'readmission': _gauge_json("Readmission Risk", pred_r_prob),
    }

    print("Finished generating gauge charts.")
    return gauges
|
847 |
+
|
848 |
+
|
849 |
+
# Additional explanation charts: waterfalls, heatmap, feature comparison, intervention impact, network graph.
|
850 |
+
def generate_additional_visualizations(df_user, shap_data_na, shap_data_r,
                                       base_value_na, base_value_r,
                                       shap_error_na, shap_error_r):
    """Generate additional Plotly JSON for various XAI visualizations.

    Builds waterfall charts, a heatmap, a feature comparison bar chart, an
    intervention impact chart and a network graph. Each chart is generated
    independently; if one fails, a placeholder from ``create_placeholder_plot``
    is stored in its place and the others are still produced.

    Args:
        df_user: Not read by this function; kept for interface compatibility.
        shap_data_na: SHAP entries for non-adherence; a list of dicts read for
            ``feature``, ``shap_value`` and ``feature_value``. May be None or
            start with an ``error`` entry when SHAP failed.
        shap_data_r: Same structure for readmission.
        base_value_na: Baseline plotted as the first waterfall bar for
            non-adherence; None yields a placeholder.
        base_value_r: Same for readmission.
        shap_error_na: Truthy when non-adherence SHAP computation failed.
        shap_error_r: Truthy when readmission SHAP computation failed.

    Returns:
        A dict of JSON strings keyed by ``waterfall_na``, ``waterfall_r``,
        ``risk_heatmap``, ``feature_comparison``, ``intervention_impact`` and
        ``network_graph``.
    """
    print("Generating additional visualizations...")
    visualizations = {}
    plot_bgcolor = 'rgba(0,0,0,0)'
    paper_bgcolor = 'rgba(0,0,0,0)'
    font_family = "Arial, sans-serif"
    font_color = "#333"

    def _shap_usable(shap_data, shap_error):
        """True when the SHAP list is a non-empty list of entries without a leading error."""
        return bool(
            not shap_error
            and isinstance(shap_data, list)
            and shap_data
            and isinstance(shap_data[0], dict)
            and "error" not in shap_data[0]
        )

    # Determine validity
    na_shap_valid = _shap_usable(shap_data_na, shap_error_na)
    r_shap_valid = _shap_usable(shap_data_r, shap_error_r)

    # --- 1. Waterfall Charts ---
    def create_waterfall(shap_data, base_value, title_suffix, is_valid):
        """Return waterfall JSON for the ten largest SHAP contributions, or a placeholder."""
        if not is_valid or base_value is None:
            return create_placeholder_plot(f'Waterfall: {title_suffix}')
        try:
            valid = [item for item in shap_data if isinstance(item, dict) and "error" not in item]
            top = sorted(valid, key=lambda x: abs(x['shap_value']), reverse=True)[:10]
            values = [item['shap_value'] for item in top]
            labels = [
                f"{item['feature'][:25]}{'...' if len(item['feature'])>25 else ''} = {item['feature_value']}"
                for item in top
            ]
            measures = ["absolute"] + ["relative"] * len(top) + ["total"]
            y = ["Average Model Output"] + labels + ["Final Prediction"]
            # NOTE(review): the closing bar sums only the contributions shown; if
            # shap_data can hold more than ten entries it will differ from the
            # model's actual output - confirm against the caller.
            x = [base_value] + values + [base_value + sum(values)]
            # Contributions get an explicit sign; the first and last bars do not.
            text = [f"{v:+.3f}" if 0 < i < len(x) - 1 else f"{v:.3f}" for i, v in enumerate(x)]

            fig = go.Figure(go.Waterfall(
                orientation="h", measure=measures, y=y, x=x, text=text,
                textposition="outside", base=0,
                connector={"line": {"color": "rgb(150,150,150)", "width": 1}},
                increasing={"marker": {"color": "#dc3545", "line": {"width": 1, "color": "#dc3545"}}},
                decreasing={"marker": {"color": "#28a745", "line": {"width": 1, "color": "#28a745"}}},
                totals={"marker": {"color": "#007bff", "line": {"width": 1, "color": "#007bff"}}}
            ))
            # Pad the x-axis range by 15% of the span (fixed 0.1 for a near-zero span).
            mn, mx = min(x), max(x)
            pad = (mx - mn) * 0.15 if (mx - mn) > 0.01 else 0.1
            fig.update_layout(
                title=f'How Factors Contribute to {title_suffix}',
                showlegend=False,
                height=max(450, 40 * len(y)),
                margin=dict(t=50, l=250, r=50, b=50),
                yaxis={'autorange': 'reversed', 'automargin': True},
                xaxis={'range': [mn - pad, mx + pad], 'title': 'Contribution to Score'},
                plot_bgcolor=plot_bgcolor,
                paper_bgcolor=paper_bgcolor,
                font=dict(family=font_family, color=font_color),
                autosize=True
            )
            return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
        except Exception as e:
            print(f"Error creating waterfall {title_suffix}: {e}")
            return create_placeholder_plot(f'Waterfall: {title_suffix}')

    visualizations['waterfall_na'] = create_waterfall(shap_data_na, base_value_na, 'Non-Adherence Risk', na_shap_valid)
    visualizations['waterfall_r'] = create_waterfall(shap_data_r, base_value_r, 'Readmission Risk', r_shap_valid)

    # --- 2. Combined SHAP for heatmap & bar chart ---
    def _complete_items(shap_data, is_valid):
        """Return entries carrying both 'feature' and 'shap_value'; empty when the list is unusable."""
        if not is_valid:
            # The input may be None here, so it must not be iterated.
            return []
        return [
            item for item in shap_data
            if isinstance(item, dict) and "error" not in item
            and 'feature' in item and 'shap_value' in item
        ]

    def _first_by_feature(items):
        """Map each feature name to its first entry."""
        lookup = {}
        for item in items:
            lookup.setdefault(item['feature'], item)
        return lookup

    def _shap_of(lookup, name):
        """SHAP value recorded for ``name``, or 0 when the feature is absent."""
        item = lookup.get(name)
        return item['shap_value'] if item is not None else 0

    valid_na = _complete_items(shap_data_na, na_shap_valid)
    valid_r = _complete_items(shap_data_r, r_shap_valid)
    na_by_feature = _first_by_feature(valid_na)
    r_by_feature = _first_by_feature(valid_r)

    combined_impact = {}
    for item in valid_na + valid_r:
        name = item['feature']
        combined_impact[name] = combined_impact.get(name, 0) + abs(item['shap_value'])
    ranked_features = sorted(combined_impact, key=lambda f: combined_impact[f], reverse=True)
    top_feats = ranked_features[:10]

    # 2a. Heatmap
    try:
        z = []
        hover_text = []
        for name in top_feats:
            na_v = _shap_of(na_by_feature, name)
            r_v = _shap_of(r_by_feature, name)
            z.append([na_v, r_v])
            fv_na = na_by_feature.get(name, {}).get('feature_value')
            fv_r = r_by_feature.get(name, {}).get('feature_value')
            display = fv_na if fv_na is not None else fv_r
            hover_text.append([
                f"<b>{name}={display}</b><br>NA: {na_v:.3f}",
                f"<b>{name}={display}</b><br>R: {r_v:.3f}",
            ])
        hm = {
            'data': [{
                'type': 'heatmap', 'z': z, 'text': hover_text, 'hoverinfo': 'text',
                'x': ['Non-Adherence', 'Readmission'], 'y': top_feats,
                'colorscale': 'RdBu_r', 'zmid': 0, 'xgap': 1, 'ygap': 1
            }],
            'layout': {
                'title': 'Top Factor Impact Comparison',
                'height': max(400, 40 * len(top_feats)),
                'margin': {'t': 60, 'l': 250, 'b': 80, 'r': 50},
                'yaxis': {'autorange': 'reversed', 'automargin': True},
                'plot_bgcolor': plot_bgcolor, 'paper_bgcolor': paper_bgcolor,
                'font': {'family': font_family, 'color': font_color},
                'autosize': True
            }
        }
        visualizations['risk_heatmap'] = json.dumps(hm, cls=plotly.utils.PlotlyJSONEncoder)
    except Exception as e:
        print(f"Error heatmap: {e}")
        visualizations['risk_heatmap'] = create_placeholder_plot('Risk Factor Heatmap')

    # 2b. Bar chart
    try:
        bars = {
            'data': [
                {'type': 'bar', 'x': top_feats,
                 'y': [abs(_shap_of(na_by_feature, name)) for name in top_feats],
                 'name': 'NA Impact'},
                {'type': 'bar', 'x': top_feats,
                 'y': [abs(_shap_of(r_by_feature, name)) for name in top_feats],
                 'name': 'Readmission Impact'}
            ],
            'layout': {
                'title': 'Feature Importance (Absolute SHAP)',
                'barmode': 'group', 'bargap': 0.15, 'bargroupgap': 0.1,
                'height': 450, 'margin': {'t': 50, 'b': 150, 'l': 60, 'r': 20},
                'xaxis': {'tickangle': -45, 'automargin': True},
                'plot_bgcolor': plot_bgcolor, 'paper_bgcolor': paper_bgcolor,
                'font': {'family': font_family, 'color': font_color},
                'autosize': True
            }
        }
        visualizations['feature_comparison'] = json.dumps(bars, cls=plotly.utils.PlotlyJSONEncoder)
    except Exception as e:
        print(f"Error bar chart: {e}")
        visualizations['feature_comparison'] = create_placeholder_plot('Feature Importance Comparison')

    # --- 3. Intervention Impact ---
    try:
        # An ordered tuple (not a set) keeps the order of equally-ranked
        # interventions identical from run to run.
        intervention_labels = (
            ('Number of Medicines', 'Reduce # Medicines'),
            ('Cost per Medicine (₹)', 'Reduce Med Cost'),
            ('Days Medicine Lasts', 'Optimize Refill'),
            ('Total Dosage per Day (mg)', 'Optimize Dosage'),
            ('Total Pills Given', 'Reduce Pill Burden'),
            ('Medicine Availability (0-1)', 'Improve Availability'),
            ('Took Medicine Day 1 (1=Yes)', 'Ensure Day 1 (1=Yes)'),
            ('Took Medicine Day 2 (1=Yes)', 'Ensure Day 2 (1=Yes)'),
            ('Took Medicine Day 3 (1=Yes)', 'Ensure Day 3 (1=Yes)'),
        )
        thresh = 0.015
        interventions = []
        for name, label in intervention_labels:
            na_v = abs(_shap_of(na_by_feature, name))
            r_v = abs(_shap_of(r_by_feature, name))
            if na_v > thresh or r_v > thresh:
                interventions.append({'intervention': label, 'na': na_v, 'r': r_v})
        topi = sorted(interventions, key=lambda entry: entry['na'] + entry['r'], reverse=True)[:6]
        if topi:
            names = [entry['intervention'] for entry in topi]
            chart = {
                'data': [
                    {'type': 'bar', 'orientation': 'h', 'y': names,
                     'x': [entry['na'] for entry in topi], 'name': 'NA Reduction',
                     'text': [f"{entry['na']:.3f}" for entry in topi],
                     'textposition': 'outside'},
                    {'type': 'bar', 'orientation': 'h', 'y': names,
                     'x': [entry['r'] for entry in topi], 'name': 'R Reduction',
                     'text': [f"{entry['r']:.3f}" for entry in topi],
                     'textposition': 'outside'}
                ],
                'layout': {
                    'title': 'Top Potential Intervention Impacts',
                    'barmode': 'group', 'height': max(350, 50 * len(topi)),
                    'margin': {'t': 50, 'l': 200, 'b': 50, 'r': 50},
                    'yaxis': {'autorange': 'reversed', 'automargin': True},
                    'plot_bgcolor': plot_bgcolor, 'paper_bgcolor': paper_bgcolor,
                    'font': {'family': font_family, 'color': font_color},
                    'autosize': True
                }
            }
            visualizations['intervention_impact'] = json.dumps(chart, cls=plotly.utils.PlotlyJSONEncoder)
        else:
            visualizations['intervention_impact'] = create_placeholder_plot(
                'Potential Intervention Impact',
                message="No significant interventions identified."
            )
    except Exception as e:
        print(f"Error interventions: {e}")
        visualizations['intervention_impact'] = create_placeholder_plot('Potential Intervention Impact')

    # --- 4. Network Graph ---
    try:
        # Imported here so a missing networkx only costs this one chart.
        import networkx as nx
        G = nx.Graph()
        top_nodes = ranked_features[:8]
        for name in top_nodes:
            G.add_node(name, size=combined_impact[name])
        # Fully connect the top features; edge weight is the sum of both impacts.
        for idx, first in enumerate(top_nodes):
            for second in top_nodes[idx + 1:]:
                G.add_edge(first, second, weight=combined_impact[first] + combined_impact[second])
        # Fixed seed keeps the layout reproducible.
        pos = nx.spring_layout(G, k=0.5, iterations=50, seed=42)

        edge_x, edge_y = [], []
        for u, v in G.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            # None breaks the line so each edge is drawn as its own segment.
            edge_x += [x0, x1, None]
            edge_y += [y0, y1, None]

        node_x, node_y = [], []
        node_sizes = [G.nodes[node]['size'] * 50 for node in G.nodes()]
        for node in G.nodes():
            px, py = pos[node]
            node_x.append(px)
            node_y.append(py)

        net = {
            'data': [
                {'type': 'scatter', 'x': edge_x, 'y': edge_y, 'mode': 'lines',
                 'line': {'width': 1, 'color': '#888'}, 'hoverinfo': 'none'},
                {'type': 'scatter', 'x': node_x, 'y': node_y, 'mode': 'markers+text',
                 'marker': {'size': node_sizes, 'color': '#1f77b4', 'opacity': 0.8},
                 'text': list(G.nodes()), 'textposition': 'top center', 'hoverinfo': 'text'}
            ],
            'layout': {
                'title': 'Risk Factor Network',
                'showlegend': False,
                'xaxis': {'visible': False}, 'yaxis': {'visible': False},
                'plot_bgcolor': plot_bgcolor, 'paper_bgcolor': paper_bgcolor,
                'margin': {'l': 20, 'r': 20, 't': 40, 'b': 20}, 'autosize': True
            }
        }
        visualizations['network_graph'] = json.dumps(net, cls=plotly.utils.PlotlyJSONEncoder)
    except Exception as e:
        print(f"Error creating network graph: {e}")
        visualizations['network_graph'] = create_placeholder_plot(
            'Risk Factor Network',
            message="Network data unavailable."
        )

    print("Finished generating additional visualizations.")
    return visualizations
|
1079 |
+
|
1080 |
+
|
1081 |
+
# --- Main Execution ---
|
1082 |
+
if __name__ == '__main__':
    # Script entry point: announce startup, then hand control to Flask's
    # built-in server.
    print("Starting Flask application...")

    # Keyword arguments for app.run, gathered in one place.
    #   debug: kept off for deployment; turn on for local development ONLY.
    #   host:  '0.0.0.0' binds on all interfaces.
    #   port:  read from the PORT environment variable, defaulting to 5000.
    run_options = {
        'debug': False,
        'host': '0.0.0.0',
        'port': int(os.environ.get("PORT", 5000)),
    }
    app.run(**run_options)
|