HebaElshimy committed on
Commit 1621c4e · verified · 1 Parent(s): fee9667

Upload 2 files

Files changed (1):
  1. app.py +621 -564

app.py CHANGED
@@ -1,469 +1,357 @@
  import gradio as gr
  import pandas as pd
- import numpy as np
  import torch
- from transformers import (
-     pipeline,
-     AutoTokenizer,
-     AutoModel,
-     AutoModelForSequenceClassification
- )
  from sentence_transformers import SentenceTransformer, CrossEncoder
  import re
- from typing import List, Dict, Tuple, Optional
- import warnings
- warnings.filterwarnings('ignore')

  # ============================================================================
- # ADVANCED MODEL INITIALIZATION
  # ============================================================================

- class AdvancedMedicalScreener:
-     def __init__(self):
-         """Initialize all advanced NLP models for medical literature screening"""
-         print("🚀 Initializing Advanced Medical Screening Models...")
-
-         # 1. Biomedical language model for embeddings
-         print("Loading PubMedBERT for medical text understanding...")
-         self.pubmed_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
-         self.pubmed_model = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
-
-         # 2. Cross-encoder for accurate semantic similarity
-         print("Loading Cross-Encoder for semantic matching...")
-         self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', max_length=512)
-
-         # 3. Zero-shot classifier for criteria matching
-         print("Loading Zero-Shot Classifier...")
-         self.zero_shot = pipeline(
-             "zero-shot-classification",
-             model="facebook/bart-large-mnli",
-             device=0 if torch.cuda.is_available() else -1
-         )
-
-         # 4. Sentence transformer for fast similarity
-         print("Loading Sentence Transformer...")
-         self.sentence_model = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
-
-         # 5. Medical NER for entity extraction (optional, lightweight)
-         print("Loading Medical NER model...")
-         try:
-             self.ner_pipeline = pipeline(
-                 "ner",
-                 model="dmis-lab/biobert-base-cased-v1.2",
-                 aggregation_strategy="simple"
-             )
-         except:
-             self.ner_pipeline = None
-             print("Note: Medical NER model not available, using fallback")
-
-         print("✅ All models loaded successfully!")
-
-         # Medical terminology expansions
-         self.medical_synonyms = {
-             'rct': ['randomized controlled trial', 'randomised controlled trial', 'randomized clinical trial'],
-             'pain': ['pain', 'nociception', 'analgesia', 'hyperalgesia', 'allodynia', 'neuropathic pain',
-                      'chronic pain', 'acute pain', 'postoperative pain', 'pain management'],
-             'surgery': ['surgery', 'surgical', 'operation', 'operative', 'postoperative', 'perioperative',
-                         'preoperative', 'surgical procedure', 'surgical intervention'],
-             'study design': ['study design', 'trial design', 'research design', 'methodology',
-                              'randomized', 'controlled', 'cohort', 'case-control', 'cross-sectional',
-                              'prospective', 'retrospective', 'observational', 'experimental'],
-             'systematic review': ['systematic review', 'meta-analysis', 'meta analysis', 'evidence synthesis'],
-             'case report': ['case report', 'case study', 'case series', 'case presentation'],
-             'clinical trial': ['clinical trial', 'clinical study', 'trial', 'intervention study'],
-         }
-
-         # Study design hierarchy for classification
-         self.study_designs = {
-             'high_quality': ['randomized controlled trial', 'systematic review', 'meta-analysis'],
-             'moderate_quality': ['cohort study', 'case-control study', 'controlled trial'],
-             'low_quality': ['case report', 'case series', 'opinion', 'editorial'],
-             'observational': ['cohort', 'case-control', 'cross-sectional', 'observational'],
-             'experimental': ['randomized', 'experimental', 'intervention', 'trial']
-         }
-
-     def get_pubmed_embedding(self, text: str) -> np.ndarray:
-         """Get PubMedBERT embedding for medical text"""
-         inputs = self.pubmed_tokenizer(
-             text,
-             return_tensors="pt",
-             truncation=True,
-             max_length=512,
-             padding=True
-         )
-
          with torch.no_grad():
-             outputs = self.pubmed_model(**inputs)
-             # Use CLS token embedding
              embedding = outputs.last_hidden_state[:, 0, :].numpy()
-
          return embedding.squeeze()

-     def expand_medical_terms(self, term: str) -> List[str]:
-         """Expand medical terms with synonyms and related concepts"""
-         term_lower = term.lower()
-         expanded = [term]
-
-         # Check for known medical synonyms
-         for key, synonyms in self.medical_synonyms.items():
-             if key in term_lower or any(syn in term_lower for syn in synonyms):
-                 expanded.extend(synonyms)
-
-         # Add variations
-         if 'pain' in term_lower:
-             expanded.extend(['analgesic', 'nociceptive', 'painful'])
-         if 'surgery' in term_lower or 'surgical' in term_lower:
-             expanded.extend(['surgeon', 'resection', 'excision', 'incision'])
-
-         return list(set(expanded))
-
-     def parse_advanced_criteria(self, criteria_text: str) -> Dict:
-         """Advanced parsing of inclusion/exclusion criteria with medical understanding"""
-         criteria = {
-             'population': [],
-             'intervention': [],
-             'comparator': [],
-             'outcomes': [],
-             'study_design': [],
-             'include_general': [],
-             'exclude_general': [],
-             'pain_related': [],
-             'surgery_related': []
-         }
-
-         lines = criteria_text.split('\n')
-         current_section = None
-         is_exclusion = False
-
-         for line in lines:
-             line_clean = line.strip()
-             line_lower = line_clean.lower()
-
-             if not line_clean:
-                 continue
-
-             # Detect exclusion context
-             if 'exclude' in line_lower:
-                 is_exclusion = True
-                 current_section = 'exclude_general'
-             elif 'include' in line_lower:
-                 is_exclusion = False
-                 current_section = 'include_general'
-
-             # Detect PICOS sections
-             elif any(term in line_lower for term in ['population:', 'participants:', 'patients:']):
-                 current_section = 'population'
-             elif any(term in line_lower for term in ['intervention:', 'exposure:', 'treatment:']):
-                 current_section = 'intervention'
-             elif any(term in line_lower for term in ['comparator:', 'control:', 'comparison:']):
-                 current_section = 'comparator'
-             elif any(term in line_lower for term in ['outcome:', 'endpoint:', 'measure:']):
-                 current_section = 'outcomes'
-             elif any(term in line_lower for term in ['study design:', 'design:', 'study type:', 'methodology:']):
-                 current_section = 'study_design'
-
-             # Special detection for pain and surgery
-             elif 'pain' in line_lower:
-                 current_section = 'pain_related'
-             elif any(term in line_lower for term in ['surgery', 'surgical', 'operation']):
-                 current_section = 'surgery_related'
-
-             # Extract criteria items
-             elif current_section:
-                 # Handle bullet points or dashes
-                 if line_clean.startswith(('-', '•', '*', '·')):
-                     item = line_clean[1:].strip()
-                     if item:
-                         # Expand medical terms
-                         expanded_items = self.expand_medical_terms(item)
-                         criteria[current_section].extend(expanded_items)
-                 # Handle comma-separated items
-                 elif ',' in line_clean and ':' not in line_clean:
-                     items = [i.strip() for i in line_clean.split(',')]
-                     for item in items:
-                         if item and len(item) > 2:
-                             expanded_items = self.expand_medical_terms(item)
-                             criteria[current_section].extend(expanded_items)
-                 # Handle single items
-                 elif line_clean and not any(marker in line_lower for marker in [':', 'population', 'intervention', 'outcome']):
-                     expanded_items = self.expand_medical_terms(line_clean)
-                     criteria[current_section].extend(expanded_items)
-
-         # Remove duplicates
-         for key in criteria:
-             criteria[key] = list(set(criteria[key]))
-
-         return criteria

-     def cross_encoder_score(self, text: str, criteria: str) -> float:
-         """Calculate cross-encoder similarity score"""
          try:
-             score = self.cross_encoder.predict([[text, criteria]])
-             # Normalize to 0-1 range
-             return float(1 / (1 + np.exp(-score[0])))
          except:
-             return 0.0
-
-     def zero_shot_classify(self, text: str, labels: List[str], hypothesis_template: str = "This study is about {}") -> Dict:
-         """Perform zero-shot classification with custom hypothesis"""
-         if not labels:
-             return {}
-
          try:
-             result = self.zero_shot(
-                 text,
-                 candidate_labels=labels,
-                 hypothesis_template=hypothesis_template,
-                 multi_label=True
-             )
-
-             # Convert to dictionary with scores
-             scores = {}
-             for label, score in zip(result['labels'], result['scores']):
-                 scores[label] = score
-             return scores
          except:
-             return {}
-
-     def evaluate_study_design(self, text: str) -> Dict:
-         """Evaluate study design quality and type"""
-         design_labels = [
-             'randomized controlled trial',
-             'systematic review',
-             'meta-analysis',
-             'cohort study',
-             'case-control study',
-             'cross-sectional study',
-             'case report',
-             'observational study',
-             'experimental study'
-         ]
-
-         scores = self.zero_shot_classify(
-             text,
-             design_labels,
-             hypothesis_template="This is a {}"
-         )
-
-         # Determine quality level
-         quality = 'unknown'
-         max_design = max(scores.items(), key=lambda x: x[1])[0] if scores else ''
-
-         for level, designs in self.study_designs.items():
-             if any(design in max_design.lower() for design in designs):
-                 quality = level
-                 break
-
-         return {
-             'design_scores': scores,
-             'primary_design': max_design,
-             'quality_level': quality
-         }
-
-     def evaluate_pain_surgery_relevance(self, text: str) -> Dict:
-         """Specifically evaluate pain and surgery relevance"""
-         # Pain-related evaluation
-         pain_terms = [
-             'chronic pain', 'acute pain', 'postoperative pain',
-             'pain management', 'analgesia', 'neuropathic pain',
-             'pain relief', 'pain control', 'pain assessment'
-         ]
-
-         pain_scores = self.zero_shot_classify(
-             text,
-             pain_terms,
-             hypothesis_template="This study involves {}"
-         )
-
-         # Surgery-related evaluation
-         surgery_terms = [
-             'surgical procedure', 'postoperative', 'perioperative',
-             'surgical intervention', 'operation', 'surgical outcomes',
-             'surgical complications', 'surgical technique'
-         ]
-
-         surgery_scores = self.zero_shot_classify(
-             text,
-             surgery_terms,
-             hypothesis_template="This study involves {}"
-         )
-
-         return {
-             'pain_relevance': max(pain_scores.values()) if pain_scores else 0,
-             'surgery_relevance': max(surgery_scores.values()) if surgery_scores else 0,
-             'pain_terms': pain_scores,
-             'surgery_terms': surgery_scores
-         }
-
-     def stage1_advanced_classification(self, title: str, abstract: str, criteria_text: str) -> Dict:
-         """Advanced Stage 1 classification using multiple NLP models"""
-
-         # Combine text
-         study_text = f"{title} {abstract}"
-         if len(study_text.strip()) < 20:
-             return {
-                 'decision': 'UNCLEAR',
-                 'confidence': 0,
-                 'reasoning': 'Insufficient text for analysis',
-                 'detailed_scores': {}
-             }
-
-         # Parse criteria with medical understanding
-         criteria = self.parse_advanced_criteria(criteria_text)
-
-         # Initialize scoring components
-         scores = {
-             'population': 0,
-             'intervention': 0,
-             'comparator': 0,
-             'outcomes': 0,
-             'study_design': 0,
-             'inclusion': 0,
-             'exclusion': 0,
-             'pain_relevance': 0,
-             'surgery_relevance': 0
-         }
-
-         reasoning_parts = []
-
-         # 1. Evaluate PICOS elements using cross-encoder
-         for element in ['population', 'intervention', 'comparator', 'outcomes']:
-             if criteria[element]:
-                 element_scores = []
-                 for criterion in criteria[element][:5]:  # Limit to top 5 to avoid overload
-                     score = self.cross_encoder_score(study_text, criterion)
-                     element_scores.append(score)
-
-                 if element_scores:
-                     scores[element] = max(element_scores)
-                     if scores[element] > 0.5:
-                         best_match = criteria[element][element_scores.index(max(element_scores))]
-                         reasoning_parts.append(f"{element.capitalize()}: '{best_match}' ({scores[element]:.2f})")
-
-         # 2. Evaluate study design
-         design_eval = self.evaluate_study_design(study_text)
-         scores['study_design'] = max(design_eval['design_scores'].values()) if design_eval['design_scores'] else 0
-         if scores['study_design'] > 0.5:
-             reasoning_parts.append(f"Study Design: {design_eval['primary_design']} ({scores['study_design']:.2f})")
-
-         # 3. Evaluate pain and surgery relevance if applicable
-         if criteria['pain_related'] or 'pain' in criteria_text.lower():
-             pain_surgery_eval = self.evaluate_pain_surgery_relevance(study_text)
-             scores['pain_relevance'] = pain_surgery_eval['pain_relevance']
-             if scores['pain_relevance'] > 0.5:
-                 reasoning_parts.append(f"Pain Relevance: {scores['pain_relevance']:.2f}")
-
-         if criteria['surgery_related'] or 'surgery' in criteria_text.lower():
-             pain_surgery_eval = self.evaluate_pain_surgery_relevance(study_text)
-             scores['surgery_relevance'] = pain_surgery_eval['surgery_relevance']
-             if scores['surgery_relevance'] > 0.5:
-                 reasoning_parts.append(f"Surgery Relevance: {scores['surgery_relevance']:.2f}")
-
-         # 4. Evaluate inclusion criteria
-         if criteria['include_general']:
-             inclusion_scores = []
-             for criterion in criteria['include_general'][:3]:
-                 score = self.cross_encoder_score(study_text, criterion)
-                 inclusion_scores.append(score)
-             scores['inclusion'] = max(inclusion_scores) if inclusion_scores else 0
-             if scores['inclusion'] > 0.5:
-                 reasoning_parts.append(f"Inclusion Match: {scores['inclusion']:.2f}")
-
-         # 5. Evaluate exclusion criteria
-         if criteria['exclude_general']:
-             exclusion_scores = []
-             for criterion in criteria['exclude_general'][:3]:
-                 score = self.cross_encoder_score(study_text, criterion)
-                 exclusion_scores.append(score)
-             scores['exclusion'] = max(exclusion_scores) if exclusion_scores else 0
-             if scores['exclusion'] > 0.6:
-                 reasoning_parts.append(f"EXCLUSION Match: {scores['exclusion']:.2f}")
-
-         # 6. Check for low-quality study designs
-         if design_eval.get('quality_level') == 'low_quality':
-             scores['exclusion'] = max(scores['exclusion'], 0.7)
-             reasoning_parts.append(f"Low Quality Design: {design_eval['primary_design']}")
-
-         # Decision Logic with Confidence Calibration
-         decision, confidence = self._make_decision_stage1(scores, design_eval)
-
-         # Format reasoning
-         if not reasoning_parts:
-             reasoning_parts.append("No strong matches found")
-         reasoning = f"Stage 1 {decision}: {'; '.join(reasoning_parts)}"
-
-         return {
-             'decision': decision,
-             'confidence': confidence,
-             'reasoning': reasoning,
-             'detailed_scores': scores,
-             'study_design': design_eval.get('primary_design', 'Unknown'),
-             'quality_level': design_eval.get('quality_level', 'Unknown')
-         }
-
-     def _make_decision_stage1(self, scores: Dict, design_eval: Dict) -> Tuple[str, int]:
-         """Make Stage 1 decision based on scores with calibrated confidence"""
-
-         # Strong exclusion criteria
-         if scores['exclusion'] > 0.65:
-             confidence = min(int(scores['exclusion'] * 100), 90)
-             return 'EXCLUDE', confidence
-
-         # Low quality design exclusion
-         if design_eval.get('quality_level') == 'low_quality' and scores['study_design'] > 0.7:
-             return 'EXCLUDE', 75
-
-         # Calculate inclusion strength
-         picos_scores = [scores['population'], scores['intervention'], scores['outcomes']]
-         relevant_picos = sum(1 for s in picos_scores if s > 0.5)
-         avg_picos = np.mean([s for s in picos_scores if s > 0.3]) if any(s > 0.3 for s in picos_scores) else 0
-
-         # Strong inclusion - multiple PICOS matches
-         if relevant_picos >= 2 and avg_picos > 0.6:
-             confidence = min(int(avg_picos * 85), 85)
-             return 'INCLUDE', confidence
-
-         # Moderate inclusion - some relevant matches
-         if relevant_picos >= 1 or scores['inclusion'] > 0.6:
-             best_score = max(scores['population'], scores['intervention'], scores['outcomes'], scores['inclusion'])
-             confidence = min(int(best_score * 75), 75)
-             return 'INCLUDE', confidence
-
-         # Special consideration for pain/surgery studies
-         if (scores['pain_relevance'] > 0.6 or scores['surgery_relevance'] > 0.6) and \
-            design_eval.get('quality_level') in ['high_quality', 'moderate_quality']:
-             confidence = 70
-             return 'INCLUDE', confidence
-
-         # Weak matches - need manual review
-         if any(s > 0.4 for s in [scores['population'], scores['intervention'], scores['outcomes']]):
-             return 'UNCLEAR', 50
-
-         # No relevant matches
-         return 'EXCLUDE', 60

  # ============================================================================
- # GRADIO INTERFACE FUNCTIONS
  # ============================================================================

- # Initialize the screener globally
- screener = None

- def initialize_screener():
-     """Initialize the screener if not already done"""
-     global screener
-     if screener is None:
-         screener = AdvancedMedicalScreener()
-     return screener

- def process_stage1_advanced(file, title_col, abstract_col, criteria, sample_size):
-     """Process Stage 1 screening with advanced NLP models"""
      try:
-         # Initialize screener
-         model = initialize_screener()
-
-         # Read CSV
          df = pd.read_csv(file.name)
          if sample_size < len(df):
              df = df.head(sample_size)
@@ -476,16 +364,13 @@ def process_stage1_advanced(file, title_col, abstract_col, criteria, sample_size
              if not title and not abstract:
                  continue

-             # Use advanced classification
-             classification = model.stage1_advanced_classification(title, abstract, criteria)

              result = {
                  'Study_ID': idx + 1,
                  'Title': title[:100] + "..." if len(title) > 100 else title,
                  'Stage1_Decision': classification['decision'],
                  'Stage1_Confidence': f"{classification['confidence']}%",
-                 'Study_Design': classification.get('study_design', 'Unknown'),
-                 'Quality_Level': classification.get('quality_level', 'Unknown'),
                  'Stage1_Reasoning': classification['reasoning'],
                  'Ready_for_Stage2': 'Yes' if classification['decision'] == 'INCLUDE' else 'No',
                  'Full_Title': title,
@@ -495,38 +380,29 @@ def process_stage1_advanced(file, title_col, abstract_col, criteria, sample_size
          results_df = pd.DataFrame(results)

-         # Generate summary
          total = len(results_df)
          included = len(results_df[results_df['Stage1_Decision'] == 'INCLUDE'])
          excluded = len(results_df[results_df['Stage1_Decision'] == 'EXCLUDE'])
          unclear = len(results_df[results_df['Stage1_Decision'] == 'UNCLEAR'])

-         # Quality breakdown
-         quality_counts = results_df['Quality_Level'].value_counts().to_dict()
-         quality_summary = "\n".join([f" - {level}: {count}" for level, count in quality_counts.items()])

          summary = f"""
- ## 📊 Advanced Stage 1 Results (AI-Powered Medical Screening)
-
- **Screening Complete with Advanced NLP Models:**
- - **Total Studies Analyzed:** {total}
- - **✅ Include for Stage 2:** {included} ({included/total*100:.1f}%)
- - **❌ Exclude:** {excluded} ({excluded/total*100:.1f}%)
- - **⚠️ Needs Manual Review:** {unclear} ({unclear/total*100:.1f}%)

- **Study Quality Distribution:**
- {quality_summary}

- **Models Used:**
- - PubMedBERT for medical text understanding
- - Cross-encoder for semantic similarity
- - Zero-shot classification for criteria matching
- - Medical NER for entity extraction

  **Next Steps:**
  1. Review {unclear} studies marked as UNCLEAR
  2. Proceed to Stage 2 with {included} included studies
- 3. Consider manual validation of borderline cases
  """

          return summary, results_df, results_df.to_csv(index=False)
@@ -534,28 +410,103 @@ def process_stage1_advanced(file, title_col, abstract_col, criteria, sample_size
      except Exception as e:
          return f"Error: {str(e)}", None, ""

- def create_advanced_interface():
-     """Create the Gradio interface with advanced NLP capabilities"""
-     with gr.Blocks(title="🔬 Advanced Medical Literature Screening", theme=gr.themes.Soft()) as interface:

-         gr.Markdown("""
- # 🔬 Advanced Medical Literature Screening with AI

- **State-of-the-art NLP models for systematic review screening**

- This tool uses advanced transformer models specifically trained on medical literature:
- - **PubMedBERT**: Understands medical terminology and concepts
- - **Cross-Encoders**: Accurate semantic matching for criteria
- - **Zero-Shot Classification**: Flexible criteria evaluation
- - **Medical NER**: Extracts medical entities automatically

- Optimized for **pain**, **surgery**, and **study design** criteria, with general medical understanding.
          """)

          with gr.Tabs():

              # STAGE 1 TAB
-             with gr.TabItem("📋 Stage 1: Advanced Title/Abstract Screening"):
                  with gr.Row():
                      with gr.Column(scale=1):
                          gr.Markdown("### 📁 Upload Study Data")
@@ -570,113 +521,200 @@ def create_advanced_interface():
                          stage1_title_col = gr.Dropdown(label="Title Column", choices=[], interactive=True)
                          stage1_abstract_col = gr.Dropdown(label="Abstract Column", choices=[], interactive=True)

-                         stage1_sample = gr.Slider(
-                             label="Studies to Process",
-                             minimum=5,
-                             maximum=500,
-                             value=100,
-                             step=5,
-                             info="Processing time increases with more studies"
-                         )

                      with gr.Column(scale=1):
-                         gr.Markdown("### 🎯 Inclusion/Exclusion Criteria")

                          stage1_criteria = gr.Textbox(
-                             label="Enter your criteria (understands medical terminology)",
                              value="""POPULATION:
- - Adult patients
- - Chronic pain patients
- - Surgical patients

  INTERVENTION:
- - Pain management interventions
- - Surgical procedures
- - Analgesic treatments

  OUTCOMES:
- - Pain intensity
- - Pain relief
- - Functional outcomes
- - Quality of life

  STUDY DESIGN:
  - Randomized controlled trials
- - Systematic reviews
  - Cohort studies
- - NOT case reports

  EXCLUDE:
  - Animal studies
- - Pediatric only
  - Case reports
- - Editorials""",
-                             lines=20,
-                             info="The AI understands medical synonyms and related terms"
                          )

-                 with gr.Row():
-                     stage1_process_btn = gr.Button(
-                         "🚀 Start Advanced AI Screening",
-                         variant="primary",
-                         scale=2
-                     )
-                     gr.Markdown("*First run may take longer to load models*", scale=1)

                  stage1_results = gr.Markdown()
-                 stage1_table = gr.Dataframe(
-                     label="Stage 1 Results with Quality Assessment",
-                     wrap=True
-                 )
                  stage1_download_data = gr.Textbox(visible=False)
-                 stage1_download_btn = gr.DownloadButton(
-                     label="💾 Download Stage 1 Results",
-                     visible=False
-                 )

-             # HELP TAB
-             with gr.TabItem("❓ Help & Guidelines"):
                  gr.Markdown("""
- ## 🤖 Advanced Features Explained

- ### **Medical Understanding**
- The system automatically:
- - Recognizes medical synonyms (e.g., RCT = randomized controlled trial)
- - Understands pain-related terms (nociception, analgesia, hyperalgesia)
- - Identifies surgical concepts (perioperative, postoperative, resection)
- - Evaluates study quality based on design

- ### **How to Write Effective Criteria**

- 1. **Be specific but comprehensive:**
-    - ✅ "chronic pain lasting > 3 months"
-    - ✅ "postoperative pain management"
-    - ❌ "pain" (too vague)

- 2. **Use medical terms freely:**
-    - The AI understands medical terminology
-    - It will automatically expand terms with synonyms
-    - Example: "surgery" → surgical, operation, resection, etc.

- 3. **Specify study designs clearly:**
-    - High quality: RCT, systematic review, meta-analysis
-    - Moderate: cohort, case-control
-    - Low: case reports, opinions

- ### **Confidence Scores**
- - **80-100%**: Strong match, high confidence
- - **60-79%**: Good match, moderate confidence
- - **40-59%**: Weak match, needs review
- - **0-39%**: Poor match, likely exclude

- ### **Tips for Best Results**
- - Include both inclusion AND exclusion criteria
- - Specify population, intervention, and outcomes
- - Mention specific study designs to include/exclude
- - The AI works best with complete abstracts
                  """)

-         # Event handlers
          def update_stage1_columns(file):
              if file is None:
                  return gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
@@ -689,31 +727,50 @@ EXCLUDE:
              except:
                  return gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

-         stage1_file.change(
-             fn=update_stage1_columns,
-             inputs=[stage1_file],
-             outputs=[stage1_title_col, stage1_abstract_col]
-         )

-         def process_with_download(*args):
-             summary, table, csv_data = process_stage1_advanced(*args)
              return summary, table, csv_data, gr.DownloadButton(visible=bool(csv_data))

          stage1_process_btn.click(
-             fn=process_with_download,
              inputs=[stage1_file, stage1_title_col, stage1_abstract_col, stage1_criteria, stage1_sample],
              outputs=[stage1_results, stage1_table, stage1_download_data, stage1_download_btn]
          )

-         stage1_download_btn.click(
-             lambda data: data,
-             inputs=[stage1_download_data],
-             outputs=[gr.File()]
          )

      return interface

  if __name__ == "__main__":
-     print("Starting Advanced Medical Literature Screening System...")
-     interface = create_advanced_interface()
      interface.launch()

  import gradio as gr
  import pandas as pd
+ import requests
+ import json
+ from transformers import pipeline, AutoTokenizer, AutoModel
  import torch

  from sentence_transformers import SentenceTransformer, CrossEncoder
+ import time
+ from typing import List, Dict, Tuple
  import re
+ import numpy as np

  # ============================================================================
+ # ADVANCED NLP MODELS INITIALIZATION
  # ============================================================================

+ print("Loading advanced models...")
+
+ # Initialize advanced models
+ try:
+     # Cross-encoder for accurate semantic similarity
+     cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', max_length=512)
+
+     # Zero-shot classifier for criteria matching
+     classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+
+     # Medical sentence transformer
+     sentence_model = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
+
+     # PubMedBERT for medical text understanding
+     pubmed_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
+     pubmed_model = AutoModel.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
+
+     print("Advanced models loaded successfully!")
+     USE_ADVANCED_MODELS = True
+ except Exception as e:
+     print(f"Warning: Could not load advanced models, falling back to basic models. Error: {e}")
+     # Fallback to basic models
+     classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+     similarity_model = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")
+     USE_ADVANCED_MODELS = False
+     print("Basic models loaded successfully!")
+
+ # Medical terminology expansions
+ MEDICAL_SYNONYMS = {
+     'rct': ['randomized controlled trial', 'randomised controlled trial', 'randomized clinical trial'],
+     'pain': ['pain', 'nociception', 'analgesia', 'hyperalgesia', 'allodynia', 'neuropathic pain',
+              'chronic pain', 'acute pain', 'postoperative pain', 'pain management'],
+     'surgery': ['surgery', 'surgical', 'operation', 'operative', 'postoperative', 'perioperative',
+                 'preoperative', 'surgical procedure', 'surgical intervention'],
+     'study design': ['study design', 'trial design', 'research design', 'methodology',
+                      'randomized', 'controlled', 'cohort', 'case-control', 'cross-sectional'],
+ }
+
+ # ============================================================================
+ # ADVANCED NLP FUNCTIONS
+ # ============================================================================
+
+ def expand_medical_terms(term: str) -> List[str]:
+     """Expand medical terms with synonyms"""
+     term_lower = term.lower()
+     expanded = [term]
+
+     for key, synonyms in MEDICAL_SYNONYMS.items():
+         if key in term_lower or any(syn in term_lower for syn in synonyms):
+             expanded.extend(synonyms[:3])  # Limit expansion
+
+     return list(set(expanded))
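Note: given the MEDICAL_SYNONYMS table above, a quick smoke test of the expansion (interactive session; output order varies because of the set()):

    >>> sorted(expand_medical_terms("postoperative pain"))
    # 'pain' matches the term directly and 'postoperative' matches a 'surgery'
    # synonym, so the first three synonyms of each key are added:
    ['analgesia', 'nociception', 'operation', 'pain', 'postoperative pain', 'surgery', 'surgical']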
+
+ def cross_encoder_score(text: str, criteria: str) -> float:
+     """Calculate cross-encoder similarity score"""
+     if not USE_ADVANCED_MODELS:
+         return 0.5  # Default score if not available
+     try:
+         score = cross_encoder.predict([[text, criteria]])
+         return float(1 / (1 + np.exp(-score[0])))
+     except:
+         return 0.5
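Note: the ms-marco cross-encoder emits an unbounded relevance logit, and the function above squashes it into [0, 1] with a sigmoid. A minimal sketch reusing the cross_encoder loaded earlier in this file (the texts are illustrative):

    logit = cross_encoder.predict([["postoperative analgesia after knee surgery",
                                    "pain management interventions"]])[0]
    prob = 1 / (1 + np.exp(-logit))  # a logit of 2.0 maps to ~0.88, -2.0 to ~0.12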
+
+ def get_pubmed_embedding(text: str) -> np.ndarray:
+     """Get PubMedBERT embedding for medical text"""
+     if not USE_ADVANCED_MODELS:
+         return np.zeros(768)
+
+     try:
+         inputs = pubmed_tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
          with torch.no_grad():
+             outputs = pubmed_model(**inputs)
          embedding = outputs.last_hidden_state[:, 0, :].numpy()
          return embedding.squeeze()
+     except:
+         return np.zeros(768)
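Note: the function pools the [CLS] token of the final hidden layer, so each text maps to a single 768-dim vector (hence the np.zeros(768) fallback). A usage sketch, reusing the cosine_similarity helper defined further down in this file; the sentences are made up:

    vec = get_pubmed_embedding("Randomized trial of epidural analgesia after hip surgery")
    print(vec.shape)  # (768,) - the hidden size of PubMedBERT-base
    other = get_pubmed_embedding("RCT of epidural pain relief following hip replacement")
    print(cosine_similarity(vec, other))  # closer to 1 for related medical texts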

+ def zero_shot_classify(text: str, labels: List[str], hypothesis_template: str = "This study is about {}") -> Dict:
+     """Perform zero-shot classification"""
+     if not labels:
+         return {}
+
+     try:
+         result = classifier(text, candidate_labels=labels[:10], hypothesis_template=hypothesis_template, multi_label=True)
+         scores = {}
+         for label, score in zip(result['labels'], result['scores']):
+             scores[label] = score
+         return scores
+     except:
+         return {}
+
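Note: candidate_labels is capped at ten because BART-MNLI runs a separate entailment pass per label against the hypothesis template. A minimal call, with invented example text and illustrative scores:

    scores = zero_shot_classify(
        "A randomized controlled trial of gabapentin for postoperative pain",
        ["pain management", "oncology", "cardiology"],
    )
    # multi_label=True scores each label independently, e.g.
    # {'pain management': 0.93, 'oncology': 0.08, 'cardiology': 0.05}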
+ # ============================================================================
+ # ENHANCED CRITERIA PARSING
+ # ============================================================================
+
+ def parse_criteria(criteria_text: str, stage: str = "stage1") -> Dict:
+     """Parse criteria with medical term expansion"""
+     criteria = {
+         'population': [], 'intervention': [], 'comparator': [], 'outcomes': [],
+         'study_design': [], 'include_general': [], 'exclude_general': []
+     }
+
+     lines = criteria_text.lower().split('\n')
+     current_section = None
+
+     for line in lines:
+         line = line.strip()
+         if not line:
+             continue
+
+         # Detect section headers
+         if any(keyword in line for keyword in ['population:', 'participants:', 'subjects:']):
+             current_section = 'population'
+         elif any(keyword in line for keyword in ['intervention:', 'exposure:', 'treatment:']):
+             current_section = 'intervention'
+         elif any(keyword in line for keyword in ['comparator:', 'control:', 'comparison:']):
+             current_section = 'comparator'
+         elif any(keyword in line for keyword in ['outcomes:', 'endpoint:', 'results:']):
+             current_section = 'outcomes'
+         elif any(keyword in line for keyword in ['study design:', 'design:', 'study type:']):
+             current_section = 'study_design'
+         elif 'include' in line and ':' in line:
+             current_section = 'include_general'
+         elif 'exclude' in line and ':' in line:
+             current_section = 'exclude_general'
+         elif line.startswith('-') and current_section:
+             term = line[1:].strip()
+             if term and len(term) > 2:
+                 # Expand medical terms if advanced models are available
+                 if USE_ADVANCED_MODELS:
+                     expanded = expand_medical_terms(term)
+                     criteria[current_section].extend(expanded)
+                 else:
+                     criteria[current_section].append(term)
+         elif current_section and not any(keyword in line for keyword in ['include', 'exclude', 'population', 'intervention', 'comparator', 'outcomes', 'study']):
+             terms = [t.strip() for t in line.split(',') if t.strip() and len(t.strip()) > 2]
+             if USE_ADVANCED_MODELS:
+                 for term in terms:
+                     expanded = expand_medical_terms(term)
+                     criteria[current_section].extend(expanded)
+             else:
+                 criteria[current_section].extend(terms)
+
+     # Remove duplicates
+     for key in criteria:
+         criteria[key] = list(set(criteria[key]))
+
+     return criteria
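Note: tracing parse_criteria on a small input (the text is lowercased first; bullet terms are additionally run through expand_medical_terms when advanced models loaded):

    text = "POPULATION:\n- Adult patients\n\nEXCLUDE:\n- Animal studies"
    criteria = parse_criteria(text)
    # criteria['population'] -> ['adult patients'] (no synonym key matches, so the
    # expansion adds nothing here); criteria['exclude_general'] -> ['animal studies'];
    # the remaining keys stay empty.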

+ # ============================================================================
+ # ENHANCED STAGE 1 CLASSIFICATION
+ # ============================================================================
+
+ def semantic_similarity_score(study_text: str, criteria_terms: List[str]) -> Tuple[float, str]:
+     """Calculate semantic similarity with advanced models if available"""
+     if not criteria_terms:
+         return 0.0, ""
+
+     best_score, best_match = 0.0, ""
+
+     if USE_ADVANCED_MODELS:
+         # Use cross-encoder for more accurate matching
+         for term in criteria_terms[:5]:  # Limit to avoid slowdown
+             score = cross_encoder_score(study_text, term)
+             if score > best_score:
+                 best_score, best_match = score, term
+     else:
+         # Fallback to basic embedding similarity
+         study_embedding = get_text_embedding(study_text)
+         for term in criteria_terms:
+             term_embedding = get_text_embedding(term)
+             similarity = cosine_similarity(study_embedding, term_embedding)
+             if similarity > best_score:
+                 best_score, best_match = similarity, term
+
+     return best_score, best_match
+
+ def cosine_similarity(a, b):
+     """Simple cosine similarity calculation"""
+     dot_product = np.dot(a, b)
+     norm_a = np.linalg.norm(a)
+     norm_b = np.linalg.norm(b)
+     return dot_product / (norm_a * norm_b) if norm_a > 0 and norm_b > 0 else 0
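Note: the zero-norm guard matters because the embedding fallbacks above return zero vectors, which would otherwise divide by zero. Worked numbers: for a = [1, 0] and b = [1, 1], dot = 1, |a| = 1, |b| = sqrt(2), so the similarity is 1/sqrt(2) ~= 0.707.

    print(cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 1.0])))  # ~0.7071
    print(cosine_similarity(np.zeros(2), np.array([1.0, 1.0])))           # 0 (guarded)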
+
+ def get_text_embedding(text):
+     """Get text embedding using the similarity model"""
+     if USE_ADVANCED_MODELS:
          try:
+             embedding = sentence_model.encode(text)
+             return embedding
          except:
+             return np.zeros(384)
+     else:
          try:
+             if 'similarity_model' in globals():
+                 embeddings = similarity_model(text)
+                 return np.mean(embeddings[0], axis=0)
+             else:
+                 return np.zeros(384)
          except:
+             return np.zeros(384)

+ def stage1_classification(title: str, abstract: str, criteria_text: str) -> Dict:
+     """Enhanced Stage 1 classification with advanced NLP when available"""
+
+     study_text = f"{title} {abstract}".lower()
+     if len(study_text.strip()) < 20:
+         return {'decision': 'UNCLEAR', 'confidence': 20, 'reasoning': 'Insufficient text', 'stage': 1}
+
+     criteria = parse_criteria(criteria_text, "stage1")
+
+     # Use zero-shot classification if available with advanced models
+     if USE_ADVANCED_MODELS and criteria['include_general']:
+         zs_scores = zero_shot_classify(
+             study_text,
+             criteria['include_general'][:5],
+             "This study is relevant to {}"
+         )
+         if zs_scores:
+             max_zs_score = max(zs_scores.values())
+             if max_zs_score > 0.7:
+                 return {
+                     'decision': 'INCLUDE',
+                     'confidence': min(int(max_zs_score * 100), 85),
+                     'reasoning': f"Stage 1 INCLUDE: High relevance to inclusion criteria ({max_zs_score:.2f})",
+                     'stage': 1
+                 }
+
+     # Calculate PICOS scores with appropriate thresholds
+     pop_score, pop_match = semantic_similarity_score(study_text, criteria['population'])
+     int_score, int_match = semantic_similarity_score(study_text, criteria['intervention'])
+     out_score, out_match = semantic_similarity_score(study_text, criteria['outcomes'])
+     design_score, design_match = semantic_similarity_score(study_text, criteria['study_design'])
+     inc_score, inc_match = semantic_similarity_score(study_text, criteria['include_general'])
+     exc_score, exc_match = semantic_similarity_score(study_text, criteria['exclude_general'])
+
+     # Adjust thresholds based on model availability
+     threshold = 0.4 if USE_ADVANCED_MODELS else 0.25
+
+     reasoning_parts = []
+     if pop_score > threshold: reasoning_parts.append(f"Population: '{pop_match}' ({pop_score:.2f})")
+     if int_score > threshold: reasoning_parts.append(f"Intervention: '{int_match}' ({int_score:.2f})")
+     if out_score > threshold: reasoning_parts.append(f"Outcome: '{out_match}' ({out_score:.2f})")
+     if design_score > threshold: reasoning_parts.append(f"Design: '{design_match}' ({design_score:.2f})")
+     if inc_score > threshold: reasoning_parts.append(f"Include: '{inc_match}' ({inc_score:.2f})")
+     if exc_score > threshold: reasoning_parts.append(f"Exclude: '{exc_match}' ({exc_score:.2f})")
+
+     # Decision Logic
+     exc_threshold = 0.5 if USE_ADVANCED_MODELS else 0.35
+     if exc_score > exc_threshold:
+         decision, confidence = 'EXCLUDE', min(int(exc_score * 100), 90)
+         reasoning = f"Stage 1 EXCLUDE: {'; '.join(reasoning_parts)}"
+     elif sum([pop_score > threshold, int_score > threshold, out_score > threshold]) >= 2 and USE_ADVANCED_MODELS:
+         avg_score = np.mean([s for s in [pop_score, int_score, out_score, design_score, inc_score] if s > threshold])
+         decision, confidence = 'INCLUDE', min(int(avg_score * 85), 85)
+         reasoning = f"Stage 1 INCLUDE (Advanced): {'; '.join(reasoning_parts)}"
+     elif sum([pop_score > 0.25, int_score > 0.25, out_score > 0.25]) >= 1:
+         avg_score = np.mean([s for s in [pop_score, int_score, out_score, design_score, inc_score] if s > 0.25])
+         decision, confidence = 'INCLUDE', min(int(avg_score * 75), 80)
+         reasoning = f"Stage 1 INCLUDE: {'; '.join(reasoning_parts)}"
+     else:
+         decision, confidence = 'UNCLEAR', 40
+         reasoning = f"Stage 1 UNCLEAR: {'; '.join(reasoning_parts) if reasoning_parts else 'No clear matches'}"
+
+     return {'decision': decision, 'confidence': confidence, 'reasoning': reasoning, 'stage': 1}
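Note: a worked example of the decision ladder with advanced models (threshold 0.4, exc_threshold 0.5). Suppose pop_score = 0.55, int_score = 0.60, out_score = 0.20 and the other scores are 0: two PICOS scores clear the threshold, so the second branch fires with avg_score = mean(0.55, 0.60) = 0.575 and confidence = min(int(0.575 * 85), 85) = 48, giving INCLUDE at 48%. These numbers are illustrative, not taken from a real run.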

  # ============================================================================
+ # STAGE 2 CLASSIFICATION (keeping original)
  # ============================================================================

+ def stage2_classification(title: str, abstract: str, full_text: str, criteria_text: str,
+                           data_extraction_fields: Dict = None) -> Dict:
+     """Stage 2: Detailed full-text screening with data extraction"""
+
+     # Combine all available text
+     study_text = f"{title} {abstract} {full_text}".lower()
+
+     if len(study_text.strip()) < 50:
+         return {'decision': 'UNCLEAR', 'confidence': 25, 'reasoning': 'Insufficient full text', 'stage': 2}
+
+     criteria = parse_criteria(criteria_text, "stage2")
+
+     # More stringent scoring for Stage 2
+     pop_score, pop_match = semantic_similarity_score(study_text, criteria['population'])
+     int_score, int_match = semantic_similarity_score(study_text, criteria['intervention'])
+     comp_score, comp_match = semantic_similarity_score(study_text, criteria['comparator'])
+     out_score, out_match = semantic_similarity_score(study_text, criteria['outcomes'])
+     design_score, design_match = semantic_similarity_score(study_text, criteria['study_design'])
+     exc_score, exc_match = semantic_similarity_score(study_text, criteria['exclude_general'])
+
+     # Data extraction scoring
+     extraction_scores = {}
+     if data_extraction_fields:
+         for field, terms in data_extraction_fields.items():
+             if terms:
+                 field_score, field_match = semantic_similarity_score(study_text, terms)
+                 extraction_scores[field] = {'score': field_score, 'match': field_match}
+
+     reasoning_parts = []
+     if pop_score > 0.3: reasoning_parts.append(f"Population: '{pop_match}' ({pop_score:.2f})")
+     if int_score > 0.3: reasoning_parts.append(f"Intervention: '{int_match}' ({int_score:.2f})")
+     if comp_score > 0.3: reasoning_parts.append(f"Comparator: '{comp_match}' ({comp_score:.2f})")
+     if out_score > 0.3: reasoning_parts.append(f"Outcome: '{out_match}' ({out_score:.2f})")
+     if design_score > 0.3: reasoning_parts.append(f"Design: '{design_match}' ({design_score:.2f})")
+     if exc_score > 0.3: reasoning_parts.append(f"Exclusion: '{exc_match}' ({exc_score:.2f})")
+
+     # Stage 2 Decision Logic (High Specificity)
+     if exc_score > 0.4:
+         decision, confidence = 'EXCLUDE', min(int(exc_score * 100), 95)
+         reasoning = f"Stage 2 EXCLUDE: {'; '.join(reasoning_parts)}"
+     elif sum([pop_score > 0.4, int_score > 0.4, out_score > 0.4, design_score > 0.4]) >= 3:
+         avg_score = np.mean([pop_score, int_score, comp_score, out_score, design_score])
+         decision, confidence = 'INCLUDE', min(int(avg_score * 85), 92)
+         reasoning = f"Stage 2 INCLUDE: {'; '.join(reasoning_parts)}"
+     elif max(pop_score, int_score, out_score) > 0.5:
+         decision, confidence = 'INCLUDE', min(int(max(pop_score, int_score, out_score) * 80), 88)
+         reasoning = f"Stage 2 INCLUDE: {'; '.join(reasoning_parts)}"
+     else:
+         decision, confidence = 'EXCLUDE', 60
+         reasoning = f"Stage 2 EXCLUDE: Insufficient criteria match. {'; '.join(reasoning_parts)}"
+
+     result = {
+         'decision': decision,
+         'confidence': confidence,
+         'reasoning': reasoning,
+         'stage': 2,
+         'extraction_data': extraction_scores
+     }
+
+     return result

+ # ============================================================================
+ # PROCESSING FUNCTIONS (keeping original structure)
+ # ============================================================================

+ def process_stage1(file, title_col, abstract_col, criteria, sample_size):
+     """Process Stage 1 screening with enhanced NLP"""
      try:
          df = pd.read_csv(file.name)
          if sample_size < len(df):
              df = df.head(sample_size)

              if not title and not abstract:
                  continue

+             classification = stage1_classification(title, abstract, criteria)

              result = {
                  'Study_ID': idx + 1,
                  'Title': title[:100] + "..." if len(title) > 100 else title,
                  'Stage1_Decision': classification['decision'],
                  'Stage1_Confidence': f"{classification['confidence']}%",
                  'Stage1_Reasoning': classification['reasoning'],
                  'Ready_for_Stage2': 'Yes' if classification['decision'] == 'INCLUDE' else 'No',
                  'Full_Title': title,

          results_df = pd.DataFrame(results)

+         # Summary for Stage 1
          total = len(results_df)
          included = len(results_df[results_df['Stage1_Decision'] == 'INCLUDE'])
          excluded = len(results_df[results_df['Stage1_Decision'] == 'EXCLUDE'])
          unclear = len(results_df[results_df['Stage1_Decision'] == 'UNCLEAR'])

+         model_info = "**Using Advanced Medical NLP Models**" if USE_ADVANCED_MODELS else "**Using Basic NLP Models**"

          summary = f"""
+ ## 📊 Stage 1 (Title/Abstract) Results

+ {model_info}

+ **Screening Complete:**
+ - **Total Studies:** {total}
+ - **Include for Stage 2:** {included} ({included/total*100:.1f}%)
+ - **Exclude:** {excluded} ({excluded/total*100:.1f}%)
+ - **Needs Manual Review:** {unclear} ({unclear/total*100:.1f}%)

  **Next Steps:**
  1. Review {unclear} studies marked as UNCLEAR
  2. Proceed to Stage 2 with {included} included studies
+ 3. Obtain full texts for Stage 2 screening
  """

          return summary, results_df, results_df.to_csv(index=False)

      except Exception as e:
          return f"Error: {str(e)}", None, ""

+ def process_stage2(file, title_col, abstract_col, fulltext_col, criteria, extraction_fields, sample_size):
+     """Process Stage 2 screening with data extraction"""
+     try:
+         df = pd.read_csv(file.name)
+
+         # Filter to only Stage 1 included studies if column exists
+         if 'Stage1_Decision' in df.columns:
+             df = df[df['Stage1_Decision'] == 'INCLUDE']
+
+         if sample_size < len(df):
+             df = df.head(sample_size)
+
+         # Parse extraction fields
+         extraction_dict = {}
+         if extraction_fields:
+             for line in extraction_fields.split('\n'):
+                 if ':' in line:
+                     field, terms = line.split(':', 1)
+                     extraction_dict[field.strip()] = [t.strip() for t in terms.split(',') if t.strip()]
+
+         results = []
+         for idx, row in df.iterrows():
+             title = str(row[title_col]) if pd.notna(row[title_col]) else ""
+             abstract = str(row[abstract_col]) if pd.notna(row[abstract_col]) else ""
+             full_text = str(row[fulltext_col]) if fulltext_col and fulltext_col in df.columns and pd.notna(row[fulltext_col]) else ""
+
+             if not title and not abstract:
+                 continue
+
+             classification = stage2_classification(title, abstract, full_text, criteria, extraction_dict)
+
+             result = {
+                 'Study_ID': idx + 1,
+                 'Title': title[:100] + "..." if len(title) > 100 else title,
+                 'Stage2_Decision': classification['decision'],
+                 'Stage2_Confidence': f"{classification['confidence']}%",
+                 'Stage2_Reasoning': classification['reasoning'],
+                 'Final_Include': 'Yes' if classification['decision'] == 'INCLUDE' else 'No',
+                 'Extraction_Data': str(classification.get('extraction_data', {})),
+                 'Full_Title': title,
+                 'Full_Abstract': abstract,
+                 'Full_Text': full_text
+             }
+             results.append(result)
+
+         results_df = pd.DataFrame(results)
+
+         # Summary for Stage 2
+         total = len(results_df)
+         final_included = len(results_df[results_df['Stage2_Decision'] == 'INCLUDE'])
+         final_excluded = len(results_df[results_df['Stage2_Decision'] == 'EXCLUDE'])
+
+         summary = f"""
+ ## 📊 Stage 2 (Full-Text) Results
+
+ **Detailed Screening Complete:**
+ - **Studies Reviewed:** {total}
+ - **Final INCLUDE:** {final_included} ({final_included/total*100:.1f}%)
+ - **Final EXCLUDE:** {final_excluded} ({final_excluded/total*100:.1f}%)
+
+ **Ready for Next Steps:**
+ - **Data Extraction:** {final_included} studies
+ - **Quality Assessment:** {final_included} studies
+ - **Evidence Synthesis:** Ready to proceed
+
+ **Recommended Actions:**
+ 1. Export {final_included} included studies for detailed data extraction
+ 2. Conduct quality assessment (ROB2, ROBINS-I, etc.)
+ 3. Begin evidence synthesis and meta-analysis planning
+ """
+
+         return summary, results_df, results_df.to_csv(index=False)
+
+     except Exception as e:
+         return f"Error: {str(e)}", None, ""

+ # ============================================================================
+ # ORIGINAL INTERFACE (PRESERVED)
+ # ============================================================================
+
+ def create_interface():
+     with gr.Blocks(title="🔬 2-Stage Systematic Review AI Assistant", theme=gr.themes.Soft()) as interface:

+         gr.Markdown("""
+ # 🔬 2-Stage Systematic Review AI Assistant

+ **Complete workflow for evidence-based systematic reviews**

+ This tool supports the full 2-stage systematic review process:
+ - **Stage 1:** Title/Abstract screening (high sensitivity)
+ - **Stage 2:** Full-text screening with data extraction (high specificity)
          """)

          with gr.Tabs():

              # STAGE 1 TAB
+             with gr.TabItem("📋 Stage 1: Title/Abstract Screening"):
                  with gr.Row():
                      with gr.Column(scale=1):
                          gr.Markdown("### 📁 Upload Study Data")

                          stage1_title_col = gr.Dropdown(label="Title Column", choices=[], interactive=True)
                          stage1_abstract_col = gr.Dropdown(label="Abstract Column", choices=[], interactive=True)

+                         stage1_sample = gr.Slider(label="Studies to Process", minimum=5, maximum=500, value=100, step=5)

                      with gr.Column(scale=1):
+                         gr.Markdown("### 🎯 Stage 1 Criteria (Broad/Sensitive)")

                          stage1_criteria = gr.Textbox(
+                             label="Inclusion/Exclusion Criteria for Stage 1",
                              value="""POPULATION:
+ - Adult participants
+ - Human studies

  INTERVENTION:
+ - [Your intervention/exposure of interest]

  OUTCOMES:
+ - [Primary outcomes of interest]

  STUDY DESIGN:
  - Randomized controlled trials
  - Cohort studies
+ - Case-control studies

  EXCLUDE:
  - Animal studies
  - Case reports
+ - Reviews (unless relevant)""",
+                             lines=15
                          )

+                         stage1_process_btn = gr.Button("🚀 Start Stage 1 Screening", variant="primary")

                  stage1_results = gr.Markdown()
+                 stage1_table = gr.Dataframe(label="Stage 1 Results")
                  stage1_download_data = gr.Textbox(visible=False)
+                 stage1_download_btn = gr.DownloadButton(label="💾 Download Stage 1 Results", visible=False)

+             # STAGE 2 TAB
+             with gr.TabItem("📄 Stage 2: Full-Text Screening"):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         gr.Markdown("### 📁 Upload Stage 1 Results or Full-Text Data")
+
+                         stage2_file = gr.File(
+                             label="Upload Stage 1 Results or Studies with Full Text",
+                             file_types=[".csv"],
+                             type="filepath"
+                         )
+
+                         with gr.Row():
+                             stage2_title_col = gr.Dropdown(label="Title Column", choices=[], interactive=True)
+                             stage2_abstract_col = gr.Dropdown(label="Abstract Column", choices=[], interactive=True)
+
+                         stage2_fulltext_col = gr.Dropdown(label="Full Text Column", choices=[], interactive=True)
+                         stage2_sample = gr.Slider(label="Studies to Process", minimum=5, maximum=200, value=50, step=5)
+
+                     with gr.Column(scale=1):
+                         gr.Markdown("### 🎯 Stage 2 Criteria (Strict/Specific)")
+
+                         stage2_criteria = gr.Textbox(
+                             label="Detailed Inclusion/Exclusion Criteria for Stage 2",
+                             value="""POPULATION:
+ - [Specific population criteria]
+ - [Age ranges, conditions, etc.]
+
+ INTERVENTION:
+ - [Detailed intervention specifications]
+ - [Dosage, duration, delivery method]
+
+ COMPARATOR:
+ - [Control group specifications]
+ - [Placebo, standard care, etc.]
+
+ OUTCOMES:
+ - [Primary endpoint definitions]
+ - [Secondary outcomes]
+ - [Measurement methods]
+
+ STUDY DESIGN:
+ - [Minimum study quality requirements]
+ - [Follow-up duration requirements]
+
+ EXCLUDE:
+ - [Specific exclusion criteria]
+ - [Study quality thresholds]""",
+                             lines=15
+                         )
+
+                         extraction_fields = gr.Textbox(
+                             label="Data Extraction Fields (Optional)",
+                             value="""Sample Size: participants, subjects, patients, n=
+ Intervention Duration: weeks, months, days, duration
+ Primary Outcome: endpoint, primary outcome, main outcome
+ Statistical Method: analysis, statistical, regression, model
+ Risk of Bias: randomization, blinding, allocation""",
+                             lines=8
+                         )
+
+                         stage2_process_btn = gr.Button("🔍 Start Stage 2 Screening", variant="primary")
+
+                 stage2_results = gr.Markdown()
+                 stage2_table = gr.Dataframe(label="Stage 2 Results with Data Extraction")
+                 stage2_download_data = gr.Textbox(visible=False)
+                 stage2_download_btn = gr.DownloadButton(label="💾 Download Final Results", visible=False)
+
+             # WORKFLOW GUIDANCE TAB
+             with gr.TabItem("📚 Systematic Review Workflow"):
                  gr.Markdown("""
+ ## 🔄 Complete 2-Stage Systematic Review Process
+
+ ### **Stage 1: Title/Abstract Screening**
+ **Objective:** High sensitivity screening to identify potentially relevant studies
+
+ **Process:**
+ 1. Upload search results from multiple databases (PubMed, Embase, etc.)
+ 2. Define broad inclusion/exclusion criteria
+ 3. AI screens titles/abstracts with high sensitivity
+ 4. Manually review "UNCLEAR" classifications
+ 5. Export studies marked for inclusion to Stage 2
+
+ **Criteria Guidelines:**
+ - Use broad terms to capture all potentially relevant studies
+ - Focus on key PICOS elements (Population, Intervention, Outcomes)
+ - Err on the side of inclusion when uncertain
+
+ ### **Stage 2: Full-Text Screening**
+ **Objective:** High specificity screening with detailed data extraction
+
+ **Process:**
+ 1. Upload Stage 1 results or add full-text content
+ 2. Define strict, specific inclusion/exclusion criteria
+ 3. AI performs detailed full-text analysis
+ 4. Extract key data points for synthesis
+ 5. Export final included studies for meta-analysis
+
+ **Criteria Guidelines:**
+ - Use specific, measurable criteria
+ - Include detailed PICOS specifications
+ - Define minimum quality thresholds
+ - Specify exact outcome measurements needed
+
+ ### **Quality Assurance Recommendations:**
+
+ **For Stage 1:**
+ - Manual review of 10-20% of AI decisions
+ - Inter-rater reliability testing with a subset
+ - Calibration exercises among reviewers
+
+ **For Stage 2:**
+ - Manual validation of all AI INCLUDE decisions
+ - Detailed reason documentation for exclusions
+ - Data extraction verification by a second reviewer
+
+ ### **After 2-Stage Screening:**
+
+ 1. **Data Extraction:** Extract detailed study characteristics
+ 2. **Quality Assessment:** Apply ROB2, ROBINS-I, or other tools
+ 3. **Evidence Synthesis:** Qualitative synthesis and meta-analysis
+ 4. **GRADE Assessment:** Evaluate certainty of evidence
+ 5. **Reporting:** Follow PRISMA guidelines
+
+ ### **Best Practices:**
+
+ - **Document everything:** Keep detailed logs of decisions and criteria
+ - **Validate AI decisions:** Use AI as assistance, not replacement
+ - **Follow guidelines:** Adhere to Cochrane and PRISMA standards
+ - **Test criteria:** Pilot with known studies before full screening
+ - **Multiple reviewers:** Have disagreements resolved by a third reviewer
+
+ ### **When to Use Each Stage:**
+
+ **Use Stage 1 when:**
+ - Starting with large search results (>1000 studies)
+ - You need to quickly filter irrelevant studies
+ - Working with title/abstract data only
+
+ **Use Stage 2 when:**
+ - You have full-text access to studies
+ - You need detailed inclusion/exclusion assessment
+ - Ready for data extraction
+ - Preparing for meta-analysis
+
+ ### **Advanced NLP Features:**
+
+ This tool now includes advanced medical NLP models when available:
+ - **PubMedBERT** for medical text understanding
+ - **Cross-encoders** for accurate semantic matching
+ - **Zero-shot classification** for flexible criteria
+ - **Medical term expansion** for comprehensive matching
+
+ The system automatically detects and uses advanced models when available,
+ falling back to basic models if needed.
                  """)

+         # Event handlers for file uploads and column detection
          def update_stage1_columns(file):
              if file is None:
                  return gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

              except:
                  return gr.Dropdown(choices=[]), gr.Dropdown(choices=[])

+         def update_stage2_columns(file):
+             if file is None:
+                 return gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
+             try:
+                 df = pd.read_csv(file.name)
+                 columns = df.columns.tolist()
+                 title_col = next((col for col in columns if 'title' in col.lower()), columns[0] if columns else None)
+                 abstract_col = next((col for col in columns if 'abstract' in col.lower()), columns[1] if len(columns) > 1 else None)
+                 fulltext_col = next((col for col in columns if any(term in col.lower() for term in ['full_text', 'fulltext', 'text', 'content'])), None)
+                 return (gr.Dropdown(choices=columns, value=title_col),
+                         gr.Dropdown(choices=columns, value=abstract_col),
+                         gr.Dropdown(choices=columns, value=fulltext_col))
+             except:
+                 return gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
+
+         # Event bindings
+         stage1_file.change(fn=update_stage1_columns, inputs=[stage1_file], outputs=[stage1_title_col, stage1_abstract_col])
+         stage2_file.change(fn=update_stage2_columns, inputs=[stage2_file], outputs=[stage2_title_col, stage2_abstract_col, stage2_fulltext_col])

+         def process_stage1_with_download(*args):
+             summary, table, csv_data = process_stage1(*args)
+             return summary, table, csv_data, gr.DownloadButton(visible=bool(csv_data))
+
+         def process_stage2_with_download(*args):
+             summary, table, csv_data = process_stage2(*args)
              return summary, table, csv_data, gr.DownloadButton(visible=bool(csv_data))

          stage1_process_btn.click(
+             fn=process_stage1_with_download,
              inputs=[stage1_file, stage1_title_col, stage1_abstract_col, stage1_criteria, stage1_sample],
              outputs=[stage1_results, stage1_table, stage1_download_data, stage1_download_btn]
          )

+         stage2_process_btn.click(
+             fn=process_stage2_with_download,
+             inputs=[stage2_file, stage2_title_col, stage2_abstract_col, stage2_fulltext_col, stage2_criteria, extraction_fields, stage2_sample],
+             outputs=[stage2_results, stage2_table, stage2_download_data, stage2_download_btn]
+         )
+
+         stage1_download_btn.click(lambda data: data, inputs=[stage1_download_data], outputs=[gr.File()])
+         stage2_download_btn.click(lambda data: data, inputs=[stage2_download_data], outputs=[gr.File()])

      return interface

  if __name__ == "__main__":
+     interface = create_interface()
      interface.launch()
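Note: for anyone reproducing this Space, a minimal requirements.txt sketch inferred from the imports in this commit (package names only; pinned versions are not part of the diff and would be assumptions):

    gradio
    pandas
    numpy
    torch
    transformers
    sentence-transformers
    requests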