AssanaliAidarkhan committed
Commit 9e90db6 · verified · 1 Parent(s): 750248d

Update app.py

Files changed (1)
  1. app.py +194 -91
app.py CHANGED
@@ -2,13 +2,17 @@ import gradio as gr
  import torch
  import torch.nn as nn
  from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPModel, CLIPProcessor
  from huggingface_hub import hf_hub_download
  from PIL import Image
  import torch.nn.functional as F
  import json

  # Model repositories
  BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"

  # Global variables
  biomedclip_model = None
@@ -16,6 +20,9 @@ biomedclip_processor = None
  biomedclip_id2label = {}
  qwen_model = None
  qwen_tokenizer = None

  class CLIPClassifier(nn.Module):
      def __init__(self, clip_model, num_classes):
@@ -53,140 +60,214 @@ def load_biomedclip():
          print(f"❌ BiomedCLIP error: {e}")
          return False

- def load_qwen_simple():
-     """Load Qwen with minimal setup"""
-     global qwen_model, qwen_tokenizer

      try:
-         print("🔄 Loading Qwen (simple)...")

-         # Load Qwen directly
-         qwen_tokenizer = AutoTokenizer.from_pretrained(
-             "Qwen/Qwen1.5-0.5B-Chat",
-             trust_remote_code=True
-         )

          qwen_model = AutoModelForCausalLM.from_pretrained(
              "Qwen/Qwen1.5-0.5B-Chat",
              torch_dtype=torch.float32,
              trust_remote_code=True
          )

-         print("✅ Qwen loaded!")
          return True

      except Exception as e:
-         print(f"❌ Qwen error: {e}")
          return False

- def classify_mri(image):
-     """Classify MRI (working code)"""
-     if biomedclip_model is None or image is None:
-         return None

      try:
-         if image.mode != 'RGB':
-             image = image.convert('RGB')
-
-         inputs = biomedclip_processor(images=image, return_tensors="pt")

-         with torch.no_grad():
-             outputs = biomedclip_model(**inputs)
-             logits = outputs['logits']
-             probabilities = F.softmax(logits, dim=1)

-         top_prob, top_idx = torch.max(probabilities, 1)
-         class_idx = top_idx.item()

-         if class_idx in biomedclip_id2label:
-             class_name = biomedclip_id2label[class_idx]
-         elif str(class_idx) in biomedclip_id2label:
-             class_name = biomedclip_id2label[str(class_idx)]
-         else:
-             class_name = f"Class_{class_idx}"

-         confidence = top_prob.item() * 100

-         return class_name, confidence

      except Exception as e:
-         print(f"Classification error: {e}")
-         return None, None

- def generate_simple_advice(class_name, confidence):
-     """Generate advice using Qwen (simple approach)"""
      global qwen_model, qwen_tokenizer

      if qwen_model is None:
          return "❌ Qwen model not loaded"

      try:
-         print(f"🔄 Generating advice for: {class_name}")

-         # Simple medical knowledge lookup
-         advice_map = {
-             "partial_acl_injury": "Partial ACL injury detected. Recommendations: Rest and avoid pivoting activities. Apply ice for 15-20 minutes several times daily. Consider physical therapy consultation. Follow-up MRI in 6-8 weeks to monitor healing.",
-             "complete_acl_tear": "Complete ACL tear detected. Urgent orthopedic consultation required. Likely surgical reconstruction needed. Immediate immobilization and avoid weight-bearing activities.",
-             "acl_sprain": "ACL sprain detected. Conservative treatment with RICE protocol (Rest, Ice, Compression, Elevation). Physical therapy for strengthening. Gradual return to activities.",
-             "normal": "ACL appears normal. Continue regular activities. If symptoms persist, consider clinical examination for other causes."
-         }
-
-         # Get base advice
-         base_advice = advice_map.get(class_name.lower(), "Consult medical professional for evaluation.")

-         # Create simple prompt for Qwen
-         simple_prompt = f"Medical diagnosis: {class_name} with {confidence:.1f}% confidence. Provide brief clinical advice:"

          # Tokenize
-         inputs = qwen_tokenizer(simple_prompt, return_tensors="pt")

          # Generate
          with torch.no_grad():
              outputs = qwen_model.generate(
                  inputs.input_ids,
-                 max_new_tokens=100,
-                 temperature=0.8,
                  do_sample=True,
-                 pad_token_id=qwen_tokenizer.eos_token_id
              )

-         # Decode
-         full_output = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)

-         # Extract just the generated part
-         if simple_prompt in full_output:
-             generated_advice = full_output.replace(simple_prompt, "").strip()
-         else:
-             generated_advice = full_output

-         # Combine base advice with Qwen advice
-         if generated_advice and len(generated_advice) > 10:
-             combined_advice = f"**Clinical Guidelines:** {base_advice}\n\n**AI Analysis:** {generated_advice}"
-         else:
-             combined_advice = base_advice

-         print(f"✅ Generated advice: {generated_advice[:50]}...")
-         return combined_advice

      except Exception as e:
-         print(f"❌ Advice generation error: {e}")
-         # Fallback to basic advice
-         return advice_map.get(class_name.lower(), "Consult medical professional for evaluation.")

- def complete_pipeline(image):
-     """Complete analysis pipeline"""

      if image is None:
          return "❌ Please upload an MRI scan", ""

-     # Step 1: Classification
-     class_name, confidence = classify_mri(image)

-     if class_name is None:
-         return "❌ Classification failed", ""

-     # Step 2: Medical advice
-     medical_advice = generate_simple_advice(class_name, confidence)

      # Format outputs
      classification_text = f"""
@@ -200,43 +281,65 @@ def complete_pipeline(image):
      """

      advice_text = f"""
- # 🏥 **Medical Recommendations**

- {medical_advice}

  ---
- ⚠️ **Disclaimer:** For educational purposes only. Consult medical professionals.
      """

      return classification_text, advice_text

  # Load models
  biomedclip_loaded = load_biomedclip()
- qwen_loaded = load_qwen_simple()

  # Create interface
- with gr.Blocks(title="Medical AI Pipeline") as app:

-     gr.Markdown("# 🏥 Medical AI Analysis Pipeline")
-     gr.Markdown("**BiomedCLIP** (Classification) + **Qwen** (Medical Advice)")

-     status = f"Status: BiomedCLIP {'✅' if biomedclip_loaded else '❌'} | Qwen {'✅' if qwen_loaded else '❌'}"
-     gr.Markdown(f"**{status}**")

      with gr.Row():
          with gr.Column():
              image_input = gr.Image(type="pil", label="📸 Upload MRI Scan")
-             analyze_btn = gr.Button("🔬 Complete Analysis", variant="primary")

          with gr.Column():
              classification_output = gr.Markdown(label="🔬 Classification")
-             advice_output = gr.Markdown(label="🏥 Medical Advice")

      analyze_btn.click(
-         fn=complete_pipeline,
          inputs=image_input,
          outputs=[classification_output, advice_output]
      )

  if __name__ == "__main__":
      app.launch()

  import torch
  import torch.nn as nn
  from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPModel, CLIPProcessor
+ from sentence_transformers import SentenceTransformer
  from huggingface_hub import hf_hub_download
  from PIL import Image
  import torch.nn.functional as F
  import json
+ import numpy as np
+ import faiss

  # Model repositories
  BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"
+ QWEN_RAG_REPO = "AssanaliAidarkhan/qwen-medical-rag"

  # Global variables
  biomedclip_model = None

  biomedclip_id2label = {}
  qwen_model = None
  qwen_tokenizer = None
+ embedding_model = None
+ medical_knowledge = []
+ faiss_index = None

  class CLIPClassifier(nn.Module):
      def __init__(self, clip_model, num_classes):
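The added imports above pull in sentence_transformers, numpy, and faiss, so the Space also needs those packages available at runtime. The commit does not touch requirements.txt, so the exact dependency file is an assumption; a plausible set of entries (faiss-cpu being the usual CPU-only FAISS build on PyPI, versions left unpinned) would be:

    # requirements.txt (illustrative sketch, not part of this commit)
    gradio
    torch
    transformers
    huggingface_hub
    Pillow
    sentence-transformers
    faiss-cpu
    numpy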
 
          print(f"❌ BiomedCLIP error: {e}")
          return False

+ def load_rag_system():
+     """Load complete RAG system"""
+     global qwen_model, qwen_tokenizer, embedding_model, medical_knowledge, faiss_index

      try:
+         print("🔄 Loading RAG system...")

+         # 1. Load medical knowledge base
+         try:
+             knowledge_path = hf_hub_download(repo_id=QWEN_RAG_REPO, filename="medical_knowledge.json")
+             with open(knowledge_path, 'r', encoding='utf-8') as f:
+                 medical_knowledge = json.load(f)
+             print(f"✅ Knowledge base: {len(medical_knowledge)} documents")
+         except Exception as e:
+             print(f"⚠️ Knowledge loading error: {e}, using fallback")
+             # Fallback knowledge base
+             medical_knowledge = [
+                 {
+                     "id": "doc1",
+                     "title": "Partial injury of the anterior cruciate ligament",
+                     "content": "Signs of a partial anterior cruciate ligament injury: thickening, increased T2 signal, partial disorganization of the fibers, the ligament can still be traced along its course",
+                     "category": "Partial ACL injury",
+                     "advice": "Recommend conservative treatment, physical therapy, follow-up MRI in 6-8 weeks"
+                 },
+                 {
+                     "id": "doc2",
+                     "title": "Complete tear of the anterior cruciate ligament",
+                     "content": "Signs of a complete anterior cruciate ligament tear: the fibers cannot be traced along their course, a zone of increased signal is seen in the projection of the ligament, hemarthrosis",
+                     "category": "Complete ACL tear",
+                     "advice": "Urgent orthopedic consultation, likely requires ACL reconstruction surgery"
+                 }
+             ]
+
+         # 2. Load embedding model
+         print("🔄 Loading embeddings...")
+         embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+
+         # 3. Create embeddings and FAISS index
+         print("🔄 Creating FAISS index...")
+         text_contents = []
+         for doc in medical_knowledge:
+             text = f"{doc.get('title', '')} {doc.get('content', '')} {doc.get('advice', '')}"
+             text_contents.append(text)
+
+         embeddings = embedding_model.encode(text_contents, convert_to_numpy=True)
+
+         # Create FAISS index
+         dimension = embeddings.shape[1]
+         faiss_index = faiss.IndexFlatIP(dimension)
+         faiss.normalize_L2(embeddings)
+         faiss_index.add(embeddings)
+
+         print(f"✅ FAISS index created with {faiss_index.ntotal} documents")

+         # 4. Load Qwen
+         print("🔄 Loading Qwen...")
+         qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", trust_remote_code=True)
          qwen_model = AutoModelForCausalLM.from_pretrained(
              "Qwen/Qwen1.5-0.5B-Chat",
              torch_dtype=torch.float32,
              trust_remote_code=True
          )

+         print("✅ RAG system loaded completely!")
          return True

      except Exception as e:
+         print(f"❌ RAG loading error: {e}")
+         import traceback
+         print(traceback.format_exc())
          return False

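For context, load_rag_system expects medical_knowledge.json in the AssanaliAidarkhan/qwen-medical-rag repo to be a JSON list of documents with the same fields as the fallback entries above. A minimal illustrative entry (the values are placeholders loosely echoing the removed advice_map, not taken from the real file):

    [
      {
        "id": "doc3",
        "title": "ACL sprain",
        "content": "MRI signs of an ACL sprain: fibers intact with mild signal increase",
        "category": "ACL sprain",
        "advice": "Conservative treatment, RICE protocol, physical therapy"
      }
    ]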
+ def retrieve_relevant_knowledge(classification_result):
+     """Retrieve relevant medical documents"""
+     global embedding_model, medical_knowledge, faiss_index
+
+     if faiss_index is None:
+         return [], "No knowledge base available"

      try:
+         # Create query for retrieval
+         query = f"Medical diagnosis {classification_result} treatment recommendations clinical advice"

+         # Get query embedding
+         query_embedding = embedding_model.encode([query], convert_to_numpy=True)
+         faiss.normalize_L2(query_embedding)

+         # Search FAISS index
+         scores, indices = faiss_index.search(query_embedding, 2)  # Top 2 documents

+         # Get relevant documents
+         retrieved_docs = []
+         context_text = ""

+         for score, idx in zip(scores[0], indices[0]):
+             if idx != -1 and idx < len(medical_knowledge):
+                 doc = medical_knowledge[idx]
+                 retrieved_docs.append((doc, float(score)))
+
+                 context_text += f"Medical Knowledge: {doc.get('content', '')}\n"
+                 context_text += f"Clinical Advice: {doc.get('advice', '')}\n"
+                 context_text += f"Category: {doc.get('category', '')}\n\n"

+         return retrieved_docs, context_text

      except Exception as e:
+         print(f"❌ Retrieval error: {e}")
+         return [], f"Retrieval error: {e}"

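The retrieval step above depends on one detail worth calling out: L2-normalized embeddings searched with an inner-product index (IndexFlatIP) give cosine similarity. A self-contained sketch of the same pattern outside the app (document texts here are made up for illustration):

    import faiss
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")
    docs = ["partial ACL injury, conservative treatment",
            "complete ACL tear, surgical reconstruction"]

    emb = model.encode(docs, convert_to_numpy=True)   # float32 array, shape (2, 384)
    faiss.normalize_L2(emb)                           # normalize so inner product == cosine
    index = faiss.IndexFlatIP(emb.shape[1])
    index.add(emb)

    query = model.encode(["Medical diagnosis complete_acl_tear treatment"], convert_to_numpy=True)
    faiss.normalize_L2(query)
    scores, ids = index.search(query, 2)              # top-2 cosine scores and document indices
    print(ids[0], scores[0])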
+ def generate_qwen_advice(classification_result, retrieved_context):
+     """Generate medical advice using Qwen with RAG context"""
      global qwen_model, qwen_tokenizer

      if qwen_model is None:
          return "❌ Qwen model not loaded"

      try:
+         print("🔄 Generating Qwen advice...")

+         # Create comprehensive prompt
+         prompt = f"""You are a medical AI assistant. Based on the MRI classification and medical knowledge provided, give clinical recommendations.
+
+ MRI Classification: {classification_result}
+
+ Retrieved Medical Knowledge:
+ {retrieved_context}
+
+ Provide specific clinical recommendations including treatment options and follow-up care:"""

+         print(f"📝 Prompt length: {len(prompt)} characters")

          # Tokenize
+         inputs = qwen_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
+
+         print(f"🔧 Input tokens: {inputs.input_ids.shape}")

          # Generate
          with torch.no_grad():
              outputs = qwen_model.generate(
                  inputs.input_ids,
+                 max_new_tokens=120,
+                 temperature=0.7,
                  do_sample=True,
+                 pad_token_id=qwen_tokenizer.eos_token_id,
+                 eos_token_id=qwen_tokenizer.eos_token_id
              )

+         # Decode only the new tokens
+         generated_tokens = outputs[0][inputs.input_ids.shape[1]:]
+         generated_text = qwen_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

+         print(f"✅ Generated: {generated_text[:100]}...")

+         if len(generated_text) < 10:
+             return "No specific recommendations generated. Consult medical professional."

+         return generated_text

      except Exception as e:
+         print(f"❌ Qwen generation error: {e}")
+         import traceback
+         print(traceback.format_exc())
+         return f"Generation error: {e}"

+ def complete_analysis(image):
+     """Complete pipeline with RAG"""

      if image is None:
          return "❌ Please upload an MRI scan", ""

+     # Step 1: Classification
+     try:
+         if biomedclip_model is None:
+             return "❌ BiomedCLIP not loaded", ""
+
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         inputs = biomedclip_processor(images=image, return_tensors="pt")
+
+         with torch.no_grad():
+             outputs = biomedclip_model(**inputs)
+             logits = outputs['logits']
+             probabilities = F.softmax(logits, dim=1)
+
+         top_prob, top_idx = torch.max(probabilities, 1)
+         class_idx = top_idx.item()
+
+         if class_idx in biomedclip_id2label:
+             class_name = biomedclip_id2label[class_idx]
+         elif str(class_idx) in biomedclip_id2label:
+             class_name = biomedclip_id2label[str(class_idx)]
+         else:
+             class_name = f"Class_{class_idx}"
+
+         confidence = top_prob.item() * 100
+         classification_result = f"{class_name} ({confidence:.1f}% confidence)"
+
+         print(f"✅ Classification: {classification_result}")
+
+     except Exception as e:
+         return f"❌ Classification error: {e}", ""

+     # Step 2: RAG retrieval
+     retrieved_docs, context = retrieve_relevant_knowledge(classification_result)

+     # Step 3: Qwen generation
+     qwen_advice = generate_qwen_advice(classification_result, context)

      # Format outputs
      classification_text = f"""

      """

      advice_text = f"""
+ # 🏥 **AI-Generated Medical Recommendations**
+
+ ## 🤖 **Qwen Analysis:**
+ {qwen_advice}

+ ## 📚 **Retrieved Medical Knowledge:**
+ {context if context else "No relevant knowledge retrieved"}
+
+ ## 📋 **Retrieved Documents:**
+ {len(retrieved_docs)} documents found and used for advice generation

  ---
+ ⚠️ **Disclaimer:** For educational purposes only. Always consult medical professionals.
      """

      return classification_text, advice_text

  # Load models
+ print("🚀 Loading complete pipeline...")
  biomedclip_loaded = load_biomedclip()
+ rag_loaded = load_rag_system()

  # Create interface
+ with gr.Blocks(title="Medical RAG Pipeline") as app:

+     gr.Markdown("# 🏥 Medical AI RAG Pipeline")
+     gr.Markdown("**BiomedCLIP** → **RAG Retrieval** → **Qwen Generation**")

+     status = f"BiomedCLIP: {'✅' if biomedclip_loaded else '❌'} | RAG: {'✅' if rag_loaded else '❌'}"
+     gr.Markdown(f"**Status:** {status}")

      with gr.Row():
          with gr.Column():
              image_input = gr.Image(type="pil", label="📸 Upload MRI Scan")
+             analyze_btn = gr.Button("🔬 Complete RAG Analysis", variant="primary")

          with gr.Column():
              classification_output = gr.Markdown(label="🔬 Classification")
+             advice_output = gr.Markdown(label="🏥 RAG-Generated Advice")

      analyze_btn.click(
+         fn=complete_analysis,
          inputs=image_input,
          outputs=[classification_output, advice_output]
      )
+
+     gr.Markdown("""
+     ### 🔄 **RAG Pipeline Process:**
+     1. **Image Classification** - BiomedCLIP analyzes MRI
+     2. **Knowledge Retrieval** - Find relevant medical documents
+     3. **Context Generation** - Qwen uses retrieved knowledge
+     4. **Advice Output** - AI-generated clinical recommendations
+
+     ### 📚 **Knowledge Base:**
+     - ACL injury types and symptoms
+     - Treatment recommendations
+     - Clinical guidelines
+     - Follow-up protocols
+     """)

  if __name__ == "__main__":
      app.launch()
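Because complete_analysis takes a PIL image and returns the two Markdown strings wired into the interface, the new pipeline can also be exercised without launching the UI. A hypothetical smoke test, placed in the same module below the load calls (sample_mri.png is an assumed local file, and both loaders must have returned True):

    from PIL import Image

    if biomedclip_loaded and rag_loaded:
        classification_md, advice_md = complete_analysis(Image.open("sample_mri.png"))
        print(classification_md)
        print(advice_md)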