rajkhanke committed on
Commit
3e4cf4b
·
verified ·
1 Parent(s): ccbb8a0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +1087 -0
app.py ADDED
@@ -0,0 +1,1087 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import joblib
4
+ import shap
5
+ import json
6
+ import plotly
7
+ import plotly.graph_objs as go
8
+ import plotly.express as px
9
+ from flask import Flask, render_template, request, jsonify
10
+ import math
11
+ import networkx as nx
12
+ import traceback # For detailed error logging
13
+ import os # For environment variables
14
+
15
app = Flask(__name__)

# --- Load Models ---
# Both model files are loaded once, at import time. The routes below call
# model_na / model_r directly, so a failed load aborts start-up.
# NOTE(review): joblib.load unpickles arbitrary Python objects -- only
# trusted .pkl files should ever be placed at these paths.
MODEL_PATH_NA = 'non-adherence_XAI.pkl'
MODEL_PATH_R = 'readmission_XAI.pkl'
try:
    model_na = joblib.load(MODEL_PATH_NA)
    model_r = joblib.load(MODEL_PATH_R)
    print("Models loaded successfully.")
    # Heuristic sanity check: warn when a model lacks predict() or exposes
    # predict_proba(), since the rest of the file treats predict() output
    # as a continuous score.
    if not hasattr(model_na, 'predict') or hasattr(model_na, 'predict_proba'):
        print(f"Warning: Model {MODEL_PATH_NA} might not be a regressor.")
    if not hasattr(model_r, 'predict') or hasattr(model_r, 'predict_proba'):
        print(f"Warning: Model {MODEL_PATH_R} might not be a regressor.")

except FileNotFoundError as e:
    print(f"FATAL ERROR: Model file not found: {e}. Make sure the .pkl files are in the correct directory.")
    print("Attempted paths:", os.path.abspath(MODEL_PATH_NA), os.path.abspath(MODEL_PATH_R))
    # SystemExit(1) instead of the site-provided exit(): it does not depend
    # on the `site` module being loaded and reports failure (non-zero
    # status) to whatever launched the process.
    raise SystemExit(1)
except Exception as e:
    print(f"FATAL ERROR: An unexpected error occurred loading models: {e}")
    traceback.print_exc()
    raise SystemExit(1)
37
+
38
# --- Mappings ---
# Form label -> integer code tables used when turning form input into
# model features, plus the inverse (code -> label) of each table.
gender_map = dict(Female=0, Male=1)
why_map = {
    reason: code
    for code, reason in enumerate(
        ('Bone', 'Brain', 'Heart', 'Infection', 'Lung', 'Stomach', 'Surgery')
    )
}
yes_no_map = dict(Yes=1, No=0)
reverse_gender_map = dict(zip(gender_map.values(), gender_map))
reverse_why_map = dict(zip(why_map.values(), why_map))
reverse_yes_no_map = dict(zip(yes_no_map.values(), yes_no_map))

# --- Feature Order (Crucial for model input) ---
# Column order of the single-row DataFrame handed to the models.
feature_order = [
    'Age',
    'Gender',
    'Why in Hospital',
    'Hospital Days',
    'Was in ICU (1=Yes)',
    'ICU Days',
    'Number of Medicines',
    'Cost per Medicine (₹)',
    'Days Medicine Lasts',
    'Total Dosage per Day (mg)',
    'Total Pills Given',
    'Medicine Availability (0-1)',
    'Took Medicine Day 1 (1=Yes)',
    'Took Medicine Day 2 (1=Yes)',
    'Took Medicine Day 3 (1=Yes)',
]
56
+
57
+ # --- Risk Level Logic ---
58
def get_risk_level(score):
    """Map a risk score to a display band.

    The score is converted to float and clamped to [0, 1]; a value that
    cannot be converted is reported on stdout and treated as 0.

    Returns a dict with 'level' ('Low' below 0.3, 'Medium' below 0.7,
    otherwise 'High'), the matching 'color', and 'percentage' (the clamped
    score as a whole-number percent string).
    """
    try:
        numeric = float(score)
    except (ValueError, TypeError):
        print(f"Warning: Invalid score '{score}' received. Defaulting to 0.")
        numeric = 0.0
    bounded = max(0.0, min(1.0, numeric))

    # (exclusive upper bound, level, color); anything past the last bound is High.
    for upper_bound, level, color in ((0.3, 'Low', 'green'), (0.7, 'Medium', 'yellow')):
        if bounded < upper_bound:
            break
    else:
        level, color = 'High', 'red'

    return {'level': level, 'color': color, 'percentage': f"{round(bounded * 100)}%"}
73
+
74
# --- Feature Metadata for UI ---
# One entry per model feature. predict() reads each entry's 'type' to decide
# how to parse the submitted value; the whole table is also passed to the
# index.html template.
def _numeric_field(description, question, help_text, ideal_range, **extra):
    """Build the metadata entry for a numeric input ('type': 'number')."""
    spec = {
        'description': description,
        'question': question,
        'help_text': help_text,
        'ideal_range': ideal_range,
        'type': 'number',
    }
    spec.update(extra)  # optional extras such as 'step', appended after 'type'
    return spec


def _choice_field(description, question, help_text, choices):
    """Build the metadata entry for a drop-down input ('type': 'select')."""
    return {
        'description': description,
        'question': question,
        'help_text': help_text,
        'options': list(choices),  # fresh list of the mapping's labels
        'type': 'select',
    }


feature_info = {
    'Age': _numeric_field(
        'Patient age in years',
        'How old is the patient?',
        'Age is a significant factor in both medication adherence and hospital readmission risk.',
        '18-100',
    ),
    'Gender': _choice_field(
        'Patient gender',
        "What is the patient's gender?",
        'Gender can influence medication adherence patterns and readmission risk for certain conditions.',
        gender_map,
    ),
    'Why in Hospital': _choice_field(
        'Primary reason for hospitalization',
        'What is the primary reason for hospitalization?',
        'Different conditions have varying impacts on adherence and readmission patterns.',
        why_map,
    ),
    'Hospital Days': _numeric_field(
        'Total days spent in hospital during this admission',
        'How many days did the patient spend in the hospital?',
        'Longer hospital stays often correlate with more complex cases and higher readmission risks.',
        '1-30',
    ),
    'Was in ICU (1=Yes)': _choice_field(
        'Whether patient spent time in ICU',
        'Did the patient spend time in the ICU?',
        'ICU stays indicate higher severity and may impact post-discharge outcomes.',
        yes_no_map,
    ),
    'ICU Days': _numeric_field(
        'Total days spent in ICU (if applicable)',
        'How many days did the patient spend in the ICU? (Enter 0 if not in ICU)',
        'Longer ICU stays typically indicate more severe conditions requiring careful post-discharge planning.',
        '0-15',
    ),
    'Number of Medicines': _numeric_field(
        'Total number of different medications prescribed',
        'How many different medications is the patient prescribed?',
        'Higher medication counts increase complexity and can lead to reduced adherence.',
        '1-12',
    ),
    'Cost per Medicine (₹)': _numeric_field(
        'Average cost per medication in rupees',
        'What is the average cost per medication (in ₹)?',
        'Higher medication costs can impact adherence due to financial constraints.',
        '10-1000',
    ),
    'Days Medicine Lasts': _numeric_field(
        'Number of days the prescribed medication will last',
        'How many days will the prescribed medication last?',
        'Longer durations between refills can affect adherence patterns.',
        '7-90',
    ),
    'Total Dosage per Day (mg)': _numeric_field(
        'Total medication dosage per day in milligrams',
        'What is the total medication dosage per day (in mg)?',
        'Higher daily dosages may indicate more severe conditions and can affect adherence.',
        '5-500',
    ),
    'Total Pills Given': _numeric_field(
        'Total number of pills provided at discharge',
        'How many total pills were given to the patient?',
        'Pill burden is a known factor in medication adherence.',
        '10-300',
    ),
    'Medicine Availability (0-1)': _numeric_field(
        'Availability score of prescribed medication (0=low, 1=high)',
        'How available is the medication (0=low, 1=high)?',
        'Limited availability can significantly impact medication adherence.',
        '0-1',
        step='0.01',  # Specify step for float input
    ),
    'Took Medicine Day 1 (1=Yes)': _choice_field(
        'Whether patient took medication on day 1 post-discharge',
        'Did the patient take their medication on day 1 after discharge?',
        'Early adherence patterns are strong predictors of overall medication adherence.',
        yes_no_map,
    ),
    'Took Medicine Day 2 (1=Yes)': _choice_field(
        'Whether patient took medication on day 2 post-discharge',
        'Did the patient take their medication on day 2 after discharge?',
        'Consistent adherence in the first days after discharge indicates better overall adherence.',
        yes_no_map,
    ),
    'Took Medicine Day 3 (1=Yes)': _choice_field(
        'Whether patient took medication on day 3 post-discharge',
        'Did the patient take their medication on day 3 after discharge?',
        'Patterns established in the first few days often continue throughout treatment.',
        yes_no_map,
    ),
}
168
+
169
# --- Risk Explanations for UI ---
# Static explanatory copy for the two predicted risks; passed to the
# index.html template by home().
def _risk_profile(title, description, low, medium, high, consequences, interventions):
    """Assemble one risk entry: title, description, per-level text and two bullet lists."""
    return {
        'title': title,
        'description': description,
        'levels': {'Low': low, 'Medium': medium, 'High': high},
        'consequences': list(consequences),
        'interventions': list(interventions),
    }


risk_explanations = {
    'non_adherence': _risk_profile(
        title='Medication Non-Adherence Risk',
        description='Medication non-adherence refers to the degree to which a patient does not follow their prescription medication regimen as directed by their healthcare provider. This includes missing doses, taking incorrect doses, or stopping treatment early.',
        low='Patient is likely to follow medication regimen as prescribed with minimal intervention needed.',
        medium='Patient may need additional support such as reminders or follow-up calls to ensure adherence.',
        high='Patient is at significant risk of not taking medications as prescribed. Consider intensive interventions and close monitoring.',
        consequences=(
            'Reduced treatment effectiveness',
            'Disease progression or complications',
            'Increased hospitalization rates',
            'Higher healthcare costs',
            'Poorer health outcomes',
        ),
        interventions=(
            'Medication reminder systems',
            'Simplified medication regimens',
            'Patient education on importance of adherence',
            'Regular follow-up calls',
            'Addressing barriers (financial, logistical, etc.)',
        ),
    ),
    'readmission': _risk_profile(
        title='Hospital Readmission Risk',
        description='Hospital readmission risk refers to the likelihood that a patient will need to return to the hospital within a short period (typically 30 days) after being discharged.',
        low='Patient has minimal risk factors for readmission and can likely be managed with standard follow-up care.',
        medium='Patient has moderate risk of readmission and may benefit from enhanced discharge planning and follow-up.',
        high='Patient is at high risk for readmission and requires comprehensive discharge planning, early follow-up, and possibly home health services.',
        consequences=(
            'Increased patient suffering',
            'Higher healthcare costs',
            'Potential complications',
            'Disruption to patient recovery',
            'Reduced hospital quality metrics',
        ),
        interventions=(
            'Comprehensive discharge planning',
            'Medication reconciliation',
            'Early (within 7 days) follow-up appointments',
            'Home health services when appropriate',
            'Patient and caregiver education',
        ),
    ),
}
209
+
210
+
211
+ # --- SHAP Explainer Initialization ---
212
+ # (Keep the get_shap_explainer function as it was)
213
# Cache of SHAP explainers, keyed by the model_key passed to get_shap_explainer.
shap_explainers = {}


def get_shap_explainer(model_key, model):
    """Return the cached SHAP explainer for *model_key*, creating it on first use.

    A model that exposes ``feature_importances_`` or whose type name contains
    'xgboost' gets a ``shap.TreeExplainer``; anything else gets the generic
    ``shap.Explainer``. Only successfully built explainers are cached.

    Returns the explainer, or None when construction failed (the failure is
    printed; nothing is cached, so the next call tries again).
    """
    if model_key in shap_explainers:
        return shap_explainers[model_key]

    print(f"Initializing SHAP explainer for {model_key}...")
    try:
        looks_like_tree = (
            hasattr(model, 'feature_importances_')
            or 'xgboost' in str(type(model)).lower()
        )
        if not looks_like_tree:
            print(f"Warning: Model for {model_key} might not be a tree model. Using generic SHAP Explainer.")
            try:
                shap_explainers[model_key] = shap.Explainer(model)
                print(f"Initialized generic SHAP Explainer for {model_key}.")
            except Exception as gen_e:
                print(f"ERROR initializing generic SHAP explainer for {model_key}: {gen_e}")
                return None
        else:
            shap_explainers[model_key] = shap.TreeExplainer(model)
            print(f"SHAP TreeExplainer for {model_key} initialized.")
    except Exception as e:
        print(f"ERROR initializing SHAP Explainer for {model_key}: {e}")
        traceback.print_exc()
        return None
    return shap_explainers.get(model_key)
238
+
239
+
240
+ # --- Flask Routes ---
241
@app.route('/')
def home():
    """Serve the input form, handing the feature and risk metadata to index.html."""
    template_context = {
        'feature_info': feature_info,
        'feature_order': feature_order,
        'risk_explanations': risk_explanations,
    }
    return render_template('index.html', **template_context)
245
+
246
@app.route('/predict', methods=['POST'])
def predict():
    """Score one patient and return predictions, explanations and charts as JSON.

    Expects a JSON object keyed by the form field ids listed in
    ``_JS_KEY_MAP``. The steps are: validate and encode the input, run both
    models, compute SHAP attributions, then build counterfactuals,
    recommendations and chart data.

    Returns a Flask JSON response. Missing or invalid input yields HTTP 400
    with ``{'success': False, 'error': ...}``; unexpected failures yield a
    generic HTTP 500 message.
    """
    try:
        # silent=True makes a malformed or non-JSON body come back as None,
        # so it is answered with the 400 below instead of raising an
        # HTTPException that the catch-all handler would report as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'success': False, 'error': 'No data received'}), 400
        if not isinstance(data, dict):
            # A JSON array/scalar has no .get(); reject it as bad input.
            return jsonify({'success': False, 'error': 'Invalid input: expected a JSON object.'}), 400
        print("Received data:", data)

        # --- Input Processing and Validation ---
        user_input, missing_features, invalid_features = _parse_user_input(data)

        if missing_features:
            return jsonify({'success': False, 'error': f"Missing input for: {', '.join(missing_features)}"}), 400
        if invalid_features:
            error_msg = "; ".join([f"{k}: {v}" for k, v in invalid_features.items()])
            return jsonify({'success': False, 'error': f"Invalid input: {error_msg}"}), 400

        # Ensure all features are present before creating DataFrame
        if len(user_input) != len(feature_order):
            provided = set(user_input.keys())
            expected = set(feature_order)
            missing = list(expected - provided)
            extra = list(provided - expected)
            err = f"Feature mismatch after processing. Missing: {missing}. Extra: {extra}."
            print(f"ERROR: {err}")
            return jsonify({'success': False, 'error': f"Internal error: Feature mismatch. Please check processing logic. Missing: {missing}"}), 500

        # Create DataFrame in the correct order
        df_user = pd.DataFrame([user_input], columns=feature_order)
        print("Processed DataFrame for prediction:\n", df_user.to_string())

        # --- Model Predictions ---
        # Raw model outputs are clamped into [0, 1] before use.
        pred_na_raw = model_na.predict(df_user)[0]
        pred_r_raw = model_r.predict(df_user)[0]
        pred_na_score = max(0.0, min(1.0, float(pred_na_raw)))
        pred_r_score = max(0.0, min(1.0, float(pred_r_raw)))
        print(f"Predictions - NA Score: {pred_na_score:.4f}, R Score: {pred_r_score:.4f}")
        risk_level_na = get_risk_level(pred_na_score)
        risk_level_r = get_risk_level(pred_r_score)

        # --- SHAP Explanations ---
        shap_data_na, base_value_na, shap_error_na = _explain_with_shap(
            get_shap_explainer('non_adherence', model_na),
            df_user, data, pred_na_score, "NA", "Non-Adherence",
        )
        shap_data_r, base_value_r, shap_error_r = _explain_with_shap(
            get_shap_explainer('readmission', model_r),
            df_user, data, pred_r_score, "R", "Readmission",
        )

        # --- Counterfactuals & Recommendations ---
        cf_data_na = generate_comprehensive_counterfactuals(df_user, model_na, "Non-Adherence", shap_data_na, shap_error_na)
        cf_data_r = generate_comprehensive_counterfactuals(df_user, model_r, "Readmission", shap_data_r, shap_error_r)
        recommendations = generate_recommendations(shap_data_na, shap_data_r, pred_na_score, pred_r_score, shap_error_na, shap_error_r)

        # --- Visualizations ---
        gauges = generate_gauge_charts(pred_na_score, pred_r_score)
        additional_visualizations = generate_additional_visualizations(
            df_user, shap_data_na, shap_data_r, base_value_na, base_value_r, shap_error_na, shap_error_r
        )

        # --- Response ---
        response = {
            'success': True,
            'predictions': {
                'non_adherence': round(pred_na_score, 3),
                'readmission': round(pred_r_score, 3),
                'risk_level_na': risk_level_na,
                'risk_level_r': risk_level_r,
                'base_value_na': round(base_value_na, 3) if base_value_na is not None else None,
                'base_value_r': round(base_value_r, 3) if base_value_r is not None else None
            },
            'explanations': {
                'shap_values_na': shap_data_na,
                'shap_values_r': shap_data_r,
                'counterfactuals_na': cf_data_na,
                'counterfactuals_r': cf_data_r,
                'shap_error_na': shap_error_na,
                'shap_error_r': shap_error_r
            },
            'recommendations': recommendations,
            'visualizations': {
                'gauges': gauges,
                'additional': additional_visualizations
            }
        }
        return jsonify(response)

    except ValueError as ve:
        print(f"Value Error during prediction processing: {ve}")
        traceback.print_exc()
        return jsonify({'success': False, 'error': f"Invalid input: {str(ve)}"}), 400
    except KeyError as ke:
        print(f"Key Error during prediction processing: {ke}")
        traceback.print_exc()
        return jsonify({'success': False, 'error': f"Missing expected data field: {str(ke)}"}), 400
    except Exception as e:
        print(f"An unexpected error occurred during prediction: {e}")
        traceback.print_exc()
        return jsonify({'success': False, 'error': "An internal server error occurred. Please try again later."}), 500


# Form field id (as sent in the request JSON) -> model feature name.
_JS_KEY_MAP = {
    'age': 'Age', 'gender': 'Gender', 'why-in-hospital': 'Why in Hospital',
    'hospital-days': 'Hospital Days', 'was-in-icu-1yes': 'Was in ICU (1=Yes)',
    'icu-days': 'ICU Days', 'number-of-medicines': 'Number of Medicines',
    'cost-per-medicine-rupees': 'Cost per Medicine (₹)', 'days-medicine-lasts': 'Days Medicine Lasts',
    'total-dosage-per-day-mg': 'Total Dosage per Day (mg)', 'total-pills-given': 'Total Pills Given',
    'medicine-availability-0-1': 'Medicine Availability (0-1)',
    'took-medicine-day-1-1yes': 'Took Medicine Day 1 (1=Yes)',
    'took-medicine-day-2-1yes': 'Took Medicine Day 2 (1=Yes)',
    'took-medicine-day-3-1yes': 'Took Medicine Day 3 (1=Yes)'
}

# Select features whose submitted text is capitalised before the Yes/No lookup.
_YES_NO_FEATURES = frozenset({
    'Was in ICU (1=Yes)',
    'Took Medicine Day 1 (1=Yes)',
    'Took Medicine Day 2 (1=Yes)',
    'Took Medicine Day 3 (1=Yes)',
})


def _parse_user_input(data):
    """Validate the request dict and encode it into model feature values.

    Returns ``(user_input, missing_features, invalid_features)``:
    ``user_input`` maps feature name -> numeric value, ``missing_features``
    lists features with no usable value, and ``invalid_features`` maps
    feature name -> human-readable reason.
    """
    user_input = {}
    missing_features = []
    invalid_features = {}

    # 'Was in ICU' is read first because it changes how 'ICU Days' is
    # validated. A non-string (e.g. JSON null) is treated as "not answered"
    # rather than crashing on .capitalize().
    raw_icu = data.get('was-in-icu-1yes')
    was_icu_value_str = raw_icu.capitalize() if isinstance(raw_icu, str) else ''
    is_icu_no = (was_icu_value_str == 'No')

    for js_key, feature in _JS_KEY_MAP.items():
        if feature not in feature_order:
            print(f"Warning: Feature '{feature}' from js_key_map not in expected feature_order list.")
            continue

        is_icu_days = (feature == 'ICU Days')
        value = data.get(js_key)

        # Handle missing values
        if value is None or str(value).strip() == '':
            # ICU days is allowed to be missing/empty only if Was in ICU is 'No'
            if is_icu_days and is_icu_no:
                user_input[feature] = 0.0
                print("Setting ICU Days to 0 as Was in ICU is No.")
            else:
                missing_features.append(feature)
            continue

        f_info = feature_info.get(feature)
        if not f_info:
            print(f"Warning: No feature_info found for '{feature}'. Skipping.")
            continue

        try:
            if f_info['type'] == 'number':
                number = float(value)
                if is_icu_days and is_icu_no and number != 0:
                    # If they entered a non-zero value but said No ICU, force to 0
                    user_input[feature] = 0.0
                    print(f"Warning: Forcing ICU Days to 0 because Was in ICU is No, but user entered {value}.")
                elif not math.isfinite(number):
                    # float() happily parses 'nan' / 'inf'; keep them away from the models.
                    invalid_features[feature] = "Value must be a finite number."
                elif is_icu_days and not is_icu_no and number < 0:
                    invalid_features[feature] = "ICU Days cannot be negative if patient was in ICU."
                else:
                    user_input[feature] = number

            elif f_info['type'] == 'select':
                mapped_value = None
                if feature == 'Gender':
                    mapped_value = gender_map.get(value)
                elif feature == 'Why in Hospital':
                    mapped_value = why_map.get(value)
                elif feature in _YES_NO_FEATURES:
                    # Only strings can be capitalised; any other JSON type
                    # simply fails the lookup and is reported as invalid.
                    lookup_value = value.capitalize() if isinstance(value, str) else value
                    mapped_value = yes_no_map.get(lookup_value)

                if mapped_value is None:
                    invalid_features[feature] = f"Invalid option: '{value}'"
                else:
                    user_input[feature] = mapped_value
            else:
                invalid_features[feature] = f"Unknown feature type: '{f_info['type']}'"

        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError: float() of an integer too large for a float.
            invalid_features[feature] = f"Invalid format for value '{value}': {e}"

    return user_input, missing_features, invalid_features


def _explain_with_shap(explainer, df_user, data, pred_score, tag, risk_name):
    """Compute per-feature SHAP attributions for one model.

    ``tag`` ("NA" / "R") and ``risk_name`` are used only in log and error
    text. ``data`` is the raw request dict, used to echo the value the user
    originally entered for each feature.

    Returns ``(shap_data, base_value, shap_error)``. On success ``shap_data``
    is a list of per-feature dicts sorted by descending |SHAP value|. On
    failure it is a one-element list holding an ``error`` message,
    ``base_value`` is 0.5 and ``shap_error`` is True.
    """
    if not explainer:
        return [{"error": f"SHAP explainer {tag} not available."}], 0.5, True

    try:
        shap_values = explainer.shap_values(df_user)
        if isinstance(shap_values, list):
            shap_vec = shap_values[0][0]  # list result: first output, first row
        elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 2:
            shap_vec = shap_values[0]
        elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 1:
            shap_vec = shap_values
        else:
            raise TypeError(f"Unexpected SHAP {tag} format: {type(shap_values)}")

        expected = explainer.expected_value
        if isinstance(expected, (list, np.ndarray)):
            # ravel() also copes with a 0-d array, which plain [0] indexing rejects.
            base_value = float(np.ravel(expected)[0])
        elif expected is not None:
            base_value = float(expected)
        else:
            base_value = 0.5
            print(f"Warning: SHAP {tag} expected_value is None. Using 0.5.")

        if shap_vec is None or len(shap_vec) != len(feature_order):
            raise ValueError(f"SHAP {tag} vector length mismatch or None.")

        print(f"SHAP {tag}: Base={base_value:.4f}, Sum={np.sum(shap_vec):.4f}, Pred={pred_score:.4f}, Total={base_value + np.sum(shap_vec):.4f}")

        # Invert the key map once instead of scanning it for every feature.
        js_key_for = {feature: js_key for js_key, feature in _JS_KEY_MAP.items()}
        shap_data = []
        for i, feature in enumerate(feature_order):
            f_info = feature_info.get(feature, {})
            shap_data.append({
                'feature': feature, 'shap_value': float(shap_vec[i]),
                'feature_value': str(data.get(js_key_for.get(feature), "N/A")),
                'numeric_value': float(df_user.iloc[0, i]),
                'description': f_info.get('description', ''), 'help_text': f_info.get('help_text', '')
            })
        shap_data.sort(key=lambda x: abs(x.get('shap_value', 0)), reverse=True)
        return shap_data, base_value, False
    except Exception as e:
        print(f"Error calculating SHAP {tag}: {e}")
        traceback.print_exc()
        return [{"error": f"Could not calculate SHAP values for {risk_name}."}], 0.5, True
500
+
501
+
502
+ # --- Placeholder JSON for Plots on Error ---
503
+ # (Keep create_placeholder_plot as it was)
504
def create_placeholder_plot(title_suffix, message="Required data not available for this chart."):
    """Return the JSON string for an empty figure that shows *message* instead of data.

    The figure has no traces, hidden axes, transparent backgrounds and a
    single centred annotation; its title is ``"<title_suffix> (Not Generated)"``.
    """
    centred_note = {
        'text': message,
        'xref': 'paper', 'yref': 'paper',
        'x': 0.5, 'y': 0.5, 'showarrow': False,
        'font': {'size': 12, 'color': '#888'}
    }
    layout = {
        'title': f'{title_suffix} (Not Generated)',
        'xaxis': {'visible': False},
        'yaxis': {'visible': False},
        'annotations': [centred_note],
        'plot_bgcolor': 'rgba(0,0,0,0)',
        'paper_bgcolor': 'rgba(0,0,0,0)'
    }
    figure = {'data': [], 'layout': layout}
    return json.dumps(figure, cls=plotly.utils.PlotlyJSONEncoder)
522
+
523
+ # --- Counterfactual Generation ---
524
+ # (Keep generate_comprehensive_counterfactuals as it was)
525
def generate_comprehensive_counterfactuals(df_original, model, target_type, shap_data, shap_error):
    """Generate simplified counterfactuals based on SHAP values and feature modifiability.

    Walks the features in order of descending |SHAP value|, skips the
    non-modifiable ones, proposes a single-feature change for each remaining
    one, and re-runs ``model.predict`` on a copy of the input row with that
    change applied. A proposal is kept only when the clamped predicted score
    drops by more than 0.01; at most five are returned.

    Args:
        df_original: single-row DataFrame of model features (never mutated).
        model: object exposing ``predict(df)``.
        target_type: risk label used in messages, e.g. "Non-Adherence".
        shap_data: list of per-feature dicts with 'feature', 'shap_value',
            'feature_value' and 'numeric_value' keys.
        shap_error: True when the SHAP step failed upstream.

    Returns:
        A list of counterfactual dicts; or a one-element list holding a
        'notes' message when nothing useful was found; or a one-element list
        holding an 'error' message when SHAP data is unusable or generation
        raised.
    """
    print(f"Generating counterfactuals for {target_type}...")
    if shap_error or not isinstance(shap_data, list) or not shap_data or (shap_data[0] and shap_data[0].get("error")):
        print(f"Skipping counterfactuals for {target_type} due to SHAP errors or missing data.")
        return [{"error": f"Could not generate counterfactuals for {target_type} because SHAP data is missing or invalid."}]

    df = df_original.copy()
    try:
        current_pred_score = max(0.0, min(1.0, float(model.predict(df)[0])))
        non_modifiable_features = {"Age", "Gender", "Why in Hospital", "Hospital Days", "Was in ICU (1=Yes)", "ICU Days"}
        ranked = sorted(
            (item for item in shap_data if "error" not in item),
            key=lambda x: abs(x.get('shap_value', 0)),
            reverse=True,
        )
        max_cfs = 5
        counterfactuals = []

        for feature_data in ranked:
            if len(counterfactuals) >= max_cfs:
                break
            feature = feature_data.get('feature')
            shap_value = feature_data.get('shap_value', 0)
            original_numeric_value = feature_data.get('numeric_value')

            if feature is None or original_numeric_value is None or abs(shap_value) < 0.01 or feature in non_modifiable_features:
                continue

            proposal = _propose_counterfactual_change(feature, shap_value, original_numeric_value)
            if proposal is None:
                continue
            new_value, suggested_val_str = proposal

            # --- Simulate Outcome ---
            df_cf = df.copy()
            df_cf[feature] = new_value
            try:
                cf_pred_score = max(0.0, min(1.0, float(model.predict(df_cf)[0])))
                change_in_score = cf_pred_score - current_pred_score
                # Keep only changes that lower the risk by more than one point.
                if change_in_score < -0.01:
                    if abs(change_in_score) > 0.10:
                        impact_magnitude = "Significant"
                    elif abs(change_in_score) > 0.05:
                        impact_magnitude = "Moderate"
                    else:
                        impact_magnitude = "Minor"
                    counterfactuals.append({
                        "feature": feature,
                        "current_value": feature_data.get('feature_value', 'N/A'),
                        "suggested_change": suggested_val_str,
                        "potential_outcome": f"Est. new risk: {cf_pred_score:.1%} ({change_in_score:+.1%})",
                        "impact_magnitude": impact_magnitude,
                        "risk_type": target_type
                    })
            except Exception as sim_e:
                # Best effort: one failed simulation must not discard the others.
                print(f"Error simulating counterfactual for {feature}: {sim_e}")

        cf_count = len(counterfactuals)  # counted before the 'notes' filler is added
        if not counterfactuals:
            counterfactuals.append({"notes": f"No simple, impactful counterfactual changes identified among top modifiable factors for {target_type} risk."})
        print(f"Finished generating {cf_count} counterfactual entries for {target_type}.")
        return counterfactuals
    except Exception as e:
        print(f"Error generating counterfactuals for {target_type}: {str(e)}")
        traceback.print_exc()
        return [{"error": f"Could not generate counterfactuals for {target_type}: {str(e)}"}]


def _propose_counterfactual_change(feature, shap_value, current_value):
    """Suggest a new value for one modifiable feature.

    Returns ``(new_value, suggestion_text)``, or None when no change is
    proposed for this feature.
    """
    if feature.startswith("Took Medicine Day"):
        # A missed dose that pushes the score up -> suggest taking it.
        if current_value == 0 and shap_value > 0.01:
            return 1, "Ensure Adherence ('Yes')"
        return None

    if feature == "Medicine Availability (0-1)":
        if shap_value < -0.01 and current_value < 0.95:
            return 1.0, "Improve towards High (1.0)"
        return None

    # Other numeric features: move 20% against the sign of the SHAP value,
    # but only for attributions beyond +/-0.02.
    if not (shap_value > 0.02 or shap_value < -0.02):
        return None
    direction = "decrease" if shap_value > 0 else "increase"
    change_perc = 0.20
    target_factor = (1 - change_perc) if direction == "decrease" else (1 + change_perc)
    tentative_new_value = current_value * target_factor
    verb = direction.capitalize()

    # Apply the per-feature floor and rounding, and word the suggestion to match.
    if feature in ("Number of Medicines", "Total Pills Given"):
        new_value = max(1, math.floor(tentative_new_value))
        suggestion = f"{verb} towards ~{int(new_value)}"
    elif feature == "Days Medicine Lasts":
        new_value = max(7, math.floor(tentative_new_value))
        suggestion = f"{verb} towards ~{int(new_value)} days"
    elif feature == "Cost per Medicine (₹)":
        new_value = max(10.0, round(tentative_new_value, 0))
        suggestion = f"{verb} towards ~₹{new_value:.0f}"
    elif feature == "Total Dosage per Day (mg)":
        new_value = max(5.0, round(tentative_new_value, 0))
        suggestion = f"{verb} towards ~{new_value:.0f} mg"
    else:
        new_value = round(tentative_new_value, 2)
        suggestion = f"{verb} towards ~{new_value:.2f}"

    # Don't simulate changes that the floor/rounding shrank to almost nothing.
    if abs(new_value - current_value) > 0.01 * abs(current_value) + 0.1:
        return new_value, suggestion
    return None
632
+
633
+ # --- Recommendation Generation ---
634
+ # (Keep generate_recommendations as it was)
635
def generate_recommendations(shap_data_na, shap_data_r, pred_na_prob, pred_r_prob, shap_error_na, shap_error_r):
    """Generate actionable recommendations based on risk levels and SHAP factors.

    Args:
        shap_data_na: List of per-feature SHAP dicts for the non-adherence
            model. Each dict is read for the keys 'feature', 'shap_value',
            'feature_value' and 'numeric_value'. Anything that is not a
            non-empty list (e.g. None) is treated as unavailable.
        shap_data_r: Same structure for the readmission model.
        pred_na_prob: Predicted non-adherence risk, passed to get_risk_level.
        pred_r_prob: Predicted readmission risk, passed to get_risk_level.
        shap_error_na: Truthy when the non-adherence SHAP computation failed.
        shap_error_r: Truthy when the readmission SHAP computation failed.

    Returns:
        A list of dicts with the keys 'category', 'priority',
        'recommendation' and 'action', ordered Critical -> High -> Medium.
    """
    print("Generating recommendations...")
    recommendations = []
    priority_map = {"Critical": 0, "High": 1, "Medium": 2, "Standard": 3, "Info": 4}

    def _sort_by_priority(recs):
        # Unknown priorities sink to the end of the list.
        recs.sort(key=lambda rec: priority_map.get(rec.get("priority", "Medium"), 99))

    def _to_number(value, cast, default):
        # Best-effort numeric parsing: missing or malformed values fall back
        # to a default that keeps the corresponding rule from firing.
        if value is None:
            return default
        try:
            return cast(value)
        except (TypeError, ValueError, OverflowError):
            return default

    na_shap_valid = bool(not shap_error_na and isinstance(shap_data_na, list)
                         and shap_data_na and "error" not in shap_data_na[0])
    r_shap_valid = bool(not shap_error_r and isinstance(shap_data_r, list)
                        and shap_data_r and "error" not in shap_data_r[0])

    na_risk_level_info = get_risk_level(pred_na_prob)
    r_risk_level_info = get_risk_level(pred_r_prob)
    na_level = na_risk_level_info['level']
    r_level = r_risk_level_info['level']

    # --- General Recommendations ---
    if na_level == 'High':
        recommendations.append({
            "category": "Overall Non-Adherence", "priority": "Critical",
            "recommendation": f"High ({na_risk_level_info['percentage']}) non-adherence risk detected.",
            "action": "Implement intensive adherence support: daily reminders, regimen simplification consult, frequent follow-up (e.g., within 3 days), assess/address specific barriers (cost, access, understanding)."
        })
    elif na_level == 'Medium':
        recommendations.append({
            "category": "Overall Non-Adherence", "priority": "High",
            "recommendation": f"Medium ({na_risk_level_info['percentage']}) non-adherence risk detected.",
            "action": "Provide adherence aids (e.g., pillbox, app reminders), schedule follow-up call within 1 week, reinforce importance of medication."
        })
    if r_level == 'High':
        recommendations.append({
            "category": "Overall Readmission", "priority": "Critical",
            "recommendation": f"High ({r_risk_level_info['percentage']}) readmission risk detected.",
            "action": "Comprehensive discharge plan crucial: schedule follow-up within 7 days, ensure medication reconciliation, consider home health/transitional care referral, detailed patient/caregiver education using teach-back."
        })
    elif r_level == 'Medium':
        recommendations.append({
            "category": "Overall Readmission", "priority": "High",
            "recommendation": f"Medium ({r_risk_level_info['percentage']}) readmission risk detected.",
            "action": "Enhanced discharge process: ensure clear instructions (written/verbal), schedule follow-up within 10-14 days, confirm patient understanding of red flags and who to contact."
        })

    # --- Specific Recommendations ---
    if not na_shap_valid and not r_shap_valid:
        recommendations.append({
            "category": "Explanations", "priority": "Medium",
            "recommendation": "Detailed factor analysis unavailable due to SHAP errors.",
            "action": "Focus on general risk levels and standard protocols. Investigate SHAP calculation issues if persistent."})
        _sort_by_priority(recommendations)
        return recommendations  # Early exit

    # Merge both explanations per feature. Each item is tagged by the list it
    # came from, so an unavailable (None / non-list) input is never touched.
    combined_shap = {}
    shap_sources = (
        ('data_na', shap_data_na if na_shap_valid else []),
        ('data_r', shap_data_r if r_shap_valid else []),
    )
    for slot, items in shap_sources:
        for item in items:
            if not isinstance(item, dict) or "error" in item:
                continue
            feature = item.get('feature')
            if not feature:
                continue
            entry = combined_shap.setdefault(
                feature, {'abs_shap_sum': 0, 'data_na': None, 'data_r': None})
            entry['abs_shap_sum'] += abs(item.get('shap_value', 0))
            entry[slot] = item
    sorted_features = sorted(combined_shap, key=lambda f: combined_shap[f]['abs_shap_sum'], reverse=True)

    rec_count = 0
    max_recs = 7
    shap_threshold = 0.02  # Minimum SHAP value to consider as a driver

    for feature in sorted_features:
        if rec_count >= max_recs:
            break

        shap_entry = combined_shap[feature]
        na_item = shap_entry['data_na']
        r_item = shap_entry['data_r']
        na_shap = na_item.get('shap_value', 0) if na_item else 0
        r_shap = r_item.get('shap_value', 0) if r_item else 0
        item_for_val = na_item or r_item
        current_val = item_for_val.get('feature_value', 'N/A') if item_for_val else 'N/A'
        numeric_val = item_for_val.get('numeric_value') if item_for_val else None

        # Only factors that push a risk score UP are treated as drivers.
        is_na_driver = na_shap > shap_threshold
        is_r_driver = r_shap > shap_threshold
        if not (is_na_driver or is_r_driver):
            continue

        priority = "High" if max(abs(na_shap), abs(r_shap)) > 0.1 else "Medium"
        if (is_na_driver and na_level == 'High') or (is_r_driver and r_level == 'High'):
            priority = "High"
        if (na_level == 'High' and na_shap > 0.15) or (r_level == 'High' and r_shap > 0.15):
            priority = "Critical"

        rec_made = False
        action_text = ""
        rec_category = ""
        rec_recommendation = ""

        # --- Generate Recommendation Content ---
        if feature.startswith("Took Medicine Day") and current_val == "No":
            if is_na_driver:  # Primarily an adherence issue
                priority = "Critical"
                rec_category = "Early Adherence Failure"
                rec_recommendation = f"Missed medication on {feature.split('(')[0].strip()} is a strong indicator of future non-adherence."
                action_text = "Immediate intervention required: follow-up call TODAY, assess reasons, counsel, establish reminders/support."
                rec_made = True
        elif feature == "Number of Medicines":
            # Name the affected risk(s) as "NA", "R" or "NA & R".
            risk_labels = " & ".join(
                label for label, flagged in (("NA", is_na_driver), ("R", is_r_driver)) if flagged)
            rec_category = "Medication Complexity"
            rec_recommendation = f"High number of medicines ({current_val}) associated with increased risk ({risk_labels})."
            action_text = "Review list for simplification/consolidation. Consider pharmacist consult for polypharmacy review."
            rec_made = True
        elif feature == "Cost per Medicine (₹)":
            cost_val = _to_number(numeric_val, float, 0)
            if cost_val > 100 and is_na_driver:  # Primarily adherence cost barrier
                rec_category = "Medication Cost Barrier"
                rec_recommendation = f"High average medication cost (₹{current_val}) may be a barrier to adherence."
                action_text = "Discuss cost concerns. Explore generics, assistance programs, or lower-cost alternatives."
                rec_made = True
        elif feature == "Medicine Availability (0-1)":
            avail_val = _to_number(numeric_val, float, 1.0)
            if avail_val < 0.5 and is_na_driver:  # Primarily adherence access issue
                rec_category = "Medication Access Issue"
                rec_recommendation = f"Reported low medication availability ({current_val}) likely hinders adherence."
                action_text = "Verify pharmacy stock pre-discharge. Help patient identify reliable source or discuss alternatives."
                rec_made = True
        elif feature == "ICU Days":
            icu_days_val = _to_number(numeric_val, int, 0)
            if icu_days_val > 0 and is_r_driver:  # Primarily readmission risk
                # Longer ICU stays raise the priority but never lower a
                # "Critical" that was already assigned above.
                if icu_days_val > 2 and priority == "Medium":
                    priority = "High"
                rec_category = "ICU History Impact"
                rec_recommendation = f"Prior ICU stay ({current_val} days) significantly increases readmission risk."
                action_text = "Intensive post-discharge support: early follow-up (≤7 days), consider transitional care/home health, meticulous med review & education."
                rec_made = True
        elif feature == "Hospital Days":
            days_val = _to_number(numeric_val, int, 0)
            if days_val > 7 and is_r_driver:  # Primarily readmission risk
                rec_category = "Length of Stay Impact"
                rec_recommendation = f"Longer hospital stay ({current_val} days) associated with increased readmission risk."
                action_text = "Reinforces need for thorough discharge planning, clear instructions, med reconciliation, and prompt follow-up (≤10 days)."
                rec_made = True
        elif feature == "Age":
            age_val = _to_number(numeric_val, int, 0)
            if age_val > 75 and is_r_driver:  # Readmission context factor
                # Age is not modifiable, so it is always reported as context.
                priority = "Medium"
                rec_category = "Age Factor (Context)"
                rec_recommendation = f"Patient's age ({current_val}) contributes moderately to readmission risk."
                action_text = "Consider age-related needs in discharge plan (support, mobility, cognition, simple instructions)."
                rec_made = True
        elif feature == "Why in Hospital":
            if is_r_driver:  # Readmission context factor
                rec_category = "Diagnosis Factor (Context)"
                rec_recommendation = f"Primary diagnosis ({current_val}) contributes to risk."
                # str() guards against a non-string feature value.
                action_text = f"Ensure condition-specific discharge education (red flags, follow-up) and care plan for managing '{str(current_val).lower()}' are emphasized."
                rec_made = True

        # Append the recommendation if one was generated
        if rec_made:
            recommendations.append({
                "category": rec_category,
                "priority": priority,
                "recommendation": rec_recommendation,
                "action": action_text
            })
            rec_count += 1

    # Add default recommendation if few specific ones generated but risk elevated
    if (na_level in ['High', 'Medium'] or r_level in ['High', 'Medium']) and len(recommendations) < 3:
        recommendations.append({
            "category": "General Follow-up", "priority": "Medium",
            "recommendation": "Overall risk is elevated. Review standard discharge protocols.",
            "action": "Ensure robust basics: use teach-back, confirm follow-up appointments, provide clear contact info."})

    # Sort final recommendations by priority
    _sort_by_priority(recommendations)

    print(f"Finished generating {len(recommendations)} recommendations.")
    return recommendations
816
+
817
+
818
+ # --- Visualization Generation ---
819
+ # (Keep generate_gauge_charts as it was)
820
def generate_gauge_charts(pred_na_prob, pred_r_prob):
    """Build the two risk gauges as Plotly figure JSON strings.

    Both probabilities are clamped to [0, 1] before use. The result is a dict
    with the keys 'non_adherence' and 'readmission', each holding one
    serialized indicator figure.
    """
    print("Generating gauge charts...")
    bar_palette = {'green': '#28a745', 'yellow': '#ffc107', 'red': '#dc3545'}
    band_low_bg, band_medium_bg, band_high_bg = '#d4edda', '#fff3cd', '#f8d7da'

    # Clamp both inputs up front so a bad value fails before any rendering.
    pred_na_prob = max(0.0, min(1.0, pred_na_prob))
    pred_r_prob = max(0.0, min(1.0, pred_r_prob))

    def _render_gauge(probability, heading):
        # Every call builds its own dicts, so the two gauges share no state.
        risk_info = get_risk_level(probability)
        percent = round(probability * 100)
        accent = bar_palette.get(risk_info['color'], '#888')
        gauge_spec = {
            'axis': {'range': [0, 100], 'tickwidth': 1, 'tickcolor': "#aaa", 'tickfont': {'size': 10}},
            'bar': {'color': accent, 'thickness': 0.3},
            'bgcolor': 'rgba(0,0,0,0)',
            'borderwidth': 0,
            'steps': [
                {'range': [0, 30], 'color': band_low_bg},
                {'range': [30, 70], 'color': band_medium_bg},
                {'range': [70, 100], 'color': band_high_bg},
            ],
            'threshold': {'line': {'color': '#666', 'width': 4}, 'thickness': 0.75, 'value': percent},
        }
        layout_spec = {
            'height': 220,
            'margin': {'t': 50, 'b': 10, 'l': 20, 'r': 20},
            'font': {'color': "#333", 'family': "Arial, sans-serif", 'size': 12},
            'paper_bgcolor': 'rgba(0,0,0,0)',
            'plot_bgcolor': 'rgba(0,0,0,0)',
            'autosize': True,
        }
        indicator = {
            'type': 'indicator',
            'mode': 'gauge+number',
            'value': percent,
            'title': {
                'text': f"<b>{heading}</b><br><span style='font-size:0.9em;'>Level: {risk_info['level']}</span>",
                'font': {'size': 14},
            },
            'gauge': gauge_spec,
            'number': {'suffix': "%", 'font': {'size': 24, 'color': accent}},
        }
        return json.dumps({'data': [indicator], 'layout': layout_spec},
                          cls=plotly.utils.PlotlyJSONEncoder)

    gauges = {}
    gauges['non_adherence'] = _render_gauge(pred_na_prob, "Non-Adherence Risk")
    gauges['readmission'] = _render_gauge(pred_r_prob, "Readmission Risk")

    print("Finished generating gauge charts.")
    return gauges
847
+
848
+
849
+ # In app.py
850
def generate_additional_visualizations(df_user, shap_data_na, shap_data_r,
                                       base_value_na, base_value_r,
                                       shap_error_na, shap_error_r):
    """Generate additional Plotly JSON for various XAI visualizations.

    Args:
        df_user: Not read by this function; kept so the call signature stays
            compatible with existing callers.
        shap_data_na: List of per-feature SHAP dicts for non-adherence. Items
            are read for 'feature', 'shap_value' and 'feature_value'. A value
            that is not a non-empty list (e.g. None) is treated as unavailable.
        shap_data_r: Same structure for readmission.
        base_value_na: Baseline model output for the non-adherence waterfall,
            or None to get a placeholder.
        base_value_r: Baseline model output for the readmission waterfall.
        shap_error_na: Truthy when the non-adherence SHAP computation failed.
        shap_error_r: Truthy when the readmission SHAP computation failed.

    Returns:
        Dict of JSON strings keyed by 'waterfall_na', 'waterfall_r',
        'risk_heatmap', 'feature_comparison', 'intervention_impact' and
        'network_graph'. A chart that cannot be built is replaced by the
        output of create_placeholder_plot.
    """
    print("Generating additional visualizations...")
    visualizations = {}
    plot_bgcolor = 'rgba(0,0,0,0)'
    paper_bgcolor = 'rgba(0,0,0,0)'
    font_family = "Arial, sans-serif"
    font_color = "#333"

    # Determine validity
    na_shap_valid = bool(not shap_error_na and isinstance(shap_data_na, list)
                         and shap_data_na and "error" not in shap_data_na[0])
    r_shap_valid = bool(not shap_error_r and isinstance(shap_data_r, list)
                        and shap_data_r and "error" not in shap_data_r[0])

    def _usable_items(shap_data, is_valid):
        # Never iterate an unavailable input (it may be None), and drop
        # entries that lack the keys every chart below depends on.
        if not is_valid:
            return []
        return [
            item for item in shap_data
            if isinstance(item, dict) and "error" not in item
            and 'feature' in item and 'shap_value' in item
        ]

    def _first_by_feature(items):
        # First occurrence wins, matching a linear first-match search.
        lookup = {}
        for item in items:
            lookup.setdefault(item['feature'], item)
        return lookup

    valid_na = _usable_items(shap_data_na, na_shap_valid)
    valid_r = _usable_items(shap_data_r, r_shap_valid)
    na_by_feature = _first_by_feature(valid_na)
    r_by_feature = _first_by_feature(valid_r)

    def _shap_of(lookup, feature):
        item = lookup.get(feature)
        return item['shap_value'] if item is not None else 0

    # --- 1. Waterfall Charts ---
    def create_waterfall(items, base_value, title_suffix, is_valid):
        if not is_valid or base_value is None:
            return create_placeholder_plot(f'Waterfall: {title_suffix}')
        try:
            ranked = sorted(items, key=lambda entry: abs(entry['shap_value']), reverse=True)
            top, rest = ranked[:10], ranked[10:]
            values = [entry['shap_value'] for entry in top]
            labels = [
                f"{entry['feature'][:25]}{'...' if len(entry['feature'])>25 else ''} = {entry['feature_value']}"
                for entry in top
            ]
            if rest:
                # Fold the remaining contributions into one bar so the total
                # reflects every factor, not only the ten displayed.
                values.append(sum(entry['shap_value'] for entry in rest))
                labels.append(f"{len(rest)} other factors")
            measures = ["absolute"] + ["relative"]*len(values) + ["total"]
            y = ["Average Model Output"] + labels + ["Final Prediction"]
            x = [base_value] + values + [base_value + sum(values)]
            text = [f"{v:+.3f}" if 0 < idx < len(x)-1 else f"{v:.3f}" for idx, v in enumerate(x)]

            fig = go.Figure(go.Waterfall(
                orientation="h", measure=measures, y=y, x=x, text=text,
                textposition="outside", base=0,
                connector={"line":{"color":"rgb(150,150,150)", "width":1}},
                increasing={"marker":{"color":"#dc3545","line":{"width":1,"color":"#dc3545"}}},
                decreasing={"marker":{"color":"#28a745","line":{"width":1,"color":"#28a745"}}},
                totals={"marker":{"color":"#007bff","line":{"width":1,"color":"#007bff"}}}
            ))
            # Size the axis from where the bars actually sit: the running
            # total after each step, plus the zero the baseline bar starts at.
            positions = [0, base_value]
            running = base_value
            for delta in values:
                running += delta
                positions.append(running)
            mn, mx = min(positions), max(positions)
            pad = (mx-mn)*0.15 if (mx-mn) > 0.01 else 0.1
            fig.update_layout(
                title=f'How Factors Contribute to {title_suffix}',
                showlegend=False,
                height=max(450, 40*len(y)),
                margin=dict(t=50, l=250, r=50, b=50),
                yaxis={'autorange':'reversed','automargin':True},
                xaxis={'range':[mn-pad, mx+pad],'title':'Contribution to Score'},
                plot_bgcolor=plot_bgcolor,
                paper_bgcolor=paper_bgcolor,
                font=dict(family=font_family, color=font_color),
                autosize=True
            )
            return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)
        except Exception as e:
            print(f"Error creating waterfall {title_suffix}: {e}")
            return create_placeholder_plot(f'Waterfall: {title_suffix}')

    visualizations['waterfall_na'] = create_waterfall(valid_na, base_value_na, 'Non-Adherence Risk', na_shap_valid)
    visualizations['waterfall_r'] = create_waterfall(valid_r, base_value_r, 'Readmission Risk', r_shap_valid)

    # --- 2. Combined SHAP for heatmap & bar chart ---
    combined_impact = {}
    for item in valid_na + valid_r:
        name = item['feature']
        combined_impact[name] = combined_impact.get(name, 0) + abs(item['shap_value'])
    ranked_feats = sorted(combined_impact, key=lambda f: combined_impact[f], reverse=True)
    top_feats = ranked_feats[:10]

    # 2a. Heatmap
    try:
        if not top_feats:
            visualizations['risk_heatmap'] = create_placeholder_plot('Risk Factor Heatmap')
        else:
            z = []
            text = []
            for f in top_feats:
                na_v = _shap_of(na_by_feature, f)
                r_v = _shap_of(r_by_feature, f)
                z.append([na_v, r_v])
                fv_na = na_by_feature.get(f, {}).get('feature_value')
                fv_r = r_by_feature.get(f, {}).get('feature_value')
                display = fv_na if fv_na is not None else fv_r
                text.append([f"<b>{f}={display}</b><br>NA: {na_v:.3f}", f"<b>{f}={display}</b><br>R: {r_v:.3f}"])
            hm = {
                'data':[{
                    'type':'heatmap','z':z,'text':text,'hoverinfo':'text',
                    'x':['Non-Adherence','Readmission'],'y':top_feats,
                    'colorscale':'RdBu_r','zmid':0,'xgap':1,'ygap':1
                }],
                'layout':{
                    'title':'Top Factor Impact Comparison',
                    'height':max(400, 40*len(top_feats)),
                    'margin':{'t':60,'l':250,'b':80,'r':50},
                    'yaxis':{'autorange':'reversed','automargin':True},
                    'plot_bgcolor':plot_bgcolor,'paper_bgcolor':paper_bgcolor,
                    'font':{'family':font_family,'color':font_color},
                    'autosize':True
                }
            }
            visualizations['risk_heatmap'] = json.dumps(hm, cls=plotly.utils.PlotlyJSONEncoder)
    except Exception as e:
        print(f"Error heatmap: {e}")
        visualizations['risk_heatmap'] = create_placeholder_plot('Risk Factor Heatmap')

    # 2b. Bar chart
    try:
        if not top_feats:
            visualizations['feature_comparison'] = create_placeholder_plot('Feature Importance Comparison')
        else:
            bars = {
                'data':[
                    {'type':'bar','x':top_feats,'y':[abs(_shap_of(na_by_feature, f)) for f in top_feats],
                     'name':'NA Impact'},
                    {'type':'bar','x':top_feats,'y':[abs(_shap_of(r_by_feature, f)) for f in top_feats],
                     'name':'Readmission Impact'}
                ],
                'layout':{
                    'title':'Feature Importance (Absolute SHAP)',
                    'barmode':'group','bargap':0.15,'bargroupgap':0.1,
                    'height':450,'margin':{'t':50,'b':150,'l':60,'r':20},
                    'xaxis':{'tickangle':-45,'automargin':True},
                    'plot_bgcolor':plot_bgcolor,'paper_bgcolor':paper_bgcolor,
                    'font':{'family':font_family,'color':font_color},
                    'autosize':True
                }
            }
            visualizations['feature_comparison'] = json.dumps(bars, cls=plotly.utils.PlotlyJSONEncoder)
    except Exception as e:
        print(f"Error bar chart: {e}")
        visualizations['feature_comparison'] = create_placeholder_plot('Feature Importance Comparison')

    # --- 3. Intervention Impact ---
    try:
        # An ordered mapping (not a set) keeps tied interventions in the same
        # order on every run.
        map_label = {
            'Number of Medicines':'Reduce # Medicines',
            'Cost per Medicine (₹)':'Reduce Med Cost',
            'Days Medicine Lasts':'Optimize Refill',
            'Total Dosage per Day (mg)':'Optimize Dosage',
            'Total Pills Given':'Reduce Pill Burden',
            'Medicine Availability (0-1)':'Improve Availability',
        }
        for day_feature in ('Took Medicine Day 1 (1=Yes)',
                            'Took Medicine Day 2 (1=Yes)',
                            'Took Medicine Day 3 (1=Yes)'):
            map_label[day_feature] = day_feature.replace('Took Medicine', 'Ensure')

        thresh = 0.015
        interventions = []
        for f, label in map_label.items():
            na_v = abs(_shap_of(na_by_feature, f))
            r_v = abs(_shap_of(r_by_feature, f))
            if na_v > thresh or r_v > thresh:
                interventions.append({'intervention':label,'na':na_v,'r':r_v})
        topi = sorted(interventions, key=lambda entry: entry['na']+entry['r'], reverse=True)[:6]
        if topi:
            chart = {'data':[
                {'type':'bar','orientation':'h','y':[entry['intervention'] for entry in topi],
                 'x':[entry['na'] for entry in topi],'name':'NA Reduction','text':[f"{entry['na']:.3f}" for entry in topi],
                 'textposition':'outside'},
                {'type':'bar','orientation':'h','y':[entry['intervention'] for entry in topi],
                 'x':[entry['r'] for entry in topi],'name':'R Reduction','text':[f"{entry['r']:.3f}" for entry in topi],
                 'textposition':'outside'}
            ],
            'layout':{
                'title':'Top Potential Intervention Impacts',
                'barmode':'group','height':max(350, 50*len(topi)),
                'margin':{'t':50,'l':200,'b':50,'r':50},
                'yaxis':{'autorange':'reversed','automargin':True},
                'plot_bgcolor':plot_bgcolor,'paper_bgcolor':paper_bgcolor,
                'font':{'family':font_family,'color':font_color},
                'autosize':True
            }}
            visualizations['intervention_impact'] = json.dumps(chart, cls=plotly.utils.PlotlyJSONEncoder)
        else:
            visualizations['intervention_impact'] = create_placeholder_plot(
                'Potential Intervention Impact',
                message="No significant interventions identified."
            )
    except Exception as e:
        print(f"Error interventions: {e}")
        visualizations['intervention_impact'] = create_placeholder_plot('Potential Intervention Impact')

    # --- 4. Network Graph ---
    try:
        topN = ranked_feats[:8]
        if not topN:
            visualizations['network_graph'] = create_placeholder_plot(
                'Risk Factor Network',
                message="Network data unavailable."
            )
        else:
            # Imported here so a missing networkx only disables this chart.
            import networkx as nx
            G = nx.Graph()
            for f in topN:
                G.add_node(f, size=combined_impact[f])
            # Fully connect the top factors; edge weight is the pair's summed impact.
            for idx, a in enumerate(topN):
                for b in topN[idx+1:]:
                    G.add_edge(a, b, weight=combined_impact[a]+combined_impact[b])
            # Fixed seed keeps the layout identical between requests.
            pos = nx.spring_layout(G, k=0.5, iterations=50, seed=42)
            ex, ey = [], []
            for u, v in G.edges():
                x0, y0 = pos[u]
                x1, y1 = pos[v]
                # None separates the individual line segments in one trace.
                ex += [x0, x1, None]
                ey += [y0, y1, None]
            node_x, node_y = [], []
            for n in G.nodes():
                px, py = pos[n]
                node_x.append(px)
                node_y.append(py)
            ns = [G.nodes[n]['size']*50 for n in G.nodes()]
            net = {
                'data':[
                    {'type':'scatter','x':ex,'y':ey,'mode':'lines','line':{'width':1,'color':'#888'},'hoverinfo':'none'},
                    {'type':'scatter','x':node_x,'y':node_y,'mode':'markers+text',
                     'marker':{'size':ns,'color':'#1f77b4','opacity':0.8},
                     'text':list(G.nodes()),'textposition':'top center','hoverinfo':'text'}
                ],
                'layout':{
                    'title':'Risk Factor Network',
                    'showlegend':False,
                    'xaxis':{'visible':False},'yaxis':{'visible':False},
                    'plot_bgcolor':paper_bgcolor,'paper_bgcolor':paper_bgcolor,
                    'margin':{'l':20,'r':20,'t':40,'b':20},'autosize':True
                }
            }
            visualizations['network_graph'] = json.dumps(net, cls=plotly.utils.PlotlyJSONEncoder)
    except Exception as e:
        print(f"Error creating network graph: {e}")
        visualizations['network_graph'] = create_placeholder_plot(
            'Risk Factor Network',
            message="Network data unavailable."
        )

    print("Finished generating additional visualizations.")
    return visualizations
1079
+
1080
+
1081
+ # --- Main Execution ---
1082
if __name__ == '__main__':
    print("Starting Flask application...")
    # The listening port is read from the PORT environment variable and
    # falls back to 5000 when that variable is not set.
    port = int(os.environ.get("PORT", 5000))
    # Debug mode stays off here; switch it on for local development only.
    app.run(
        host='0.0.0.0',
        port=port,
        debug=False,
    )