Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
import re | |
import logging | |
import os | |
# Configure logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
handlers=[ | |
logging.StreamHandler(), | |
logging.FileHandler('medgemma_app.log') | |
] | |
) | |
logger = logging.getLogger(__name__) | |
class MedGemmaSymptomAnalyzer: | |
def __init__(self): | |
self.model = None | |
self.tokenizer = None | |
self.model_loaded = False | |
logger.info("Initializing MedGemma Symptom Analyzer...") | |
def load_model(self): | |
"""Load MedGemma model with optimizations for deployment and CPU compatibility""" | |
if self.model_loaded: | |
return True | |
model_name = "google/medgemma-4b-it" | |
logger.info(f"Loading model: {model_name}") | |
# Detect available device and log system info | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
logger.info(f"Device detected: {device}") | |
if device == "cpu": | |
logger.info(f"CPU threads available: {torch.get_num_threads()}") | |
else: | |
logger.info(f"CUDA device: {torch.cuda.get_device_name()}") | |
try: | |
# Get HF token from environment (set in Hugging Face Spaces secrets) | |
hf_token = os.getenv("HF_TOKEN") | |
if hf_token: | |
logger.info("Using HF_TOKEN for authentication") | |
else: | |
logger.warning("HF_TOKEN not found in environment variables") | |
# Optimize settings based on device | |
if device == "cpu": | |
logger.info("Configuring for CPU-optimized loading...") | |
torch_dtype = torch.float32 | |
device_map = "cpu" | |
# Set optimal number of threads for CPU inference | |
torch.set_num_threads(max(1, torch.get_num_threads() // 2)) | |
# Additional CPU optimizations | |
import psutil | |
available_memory_gb = psutil.virtual_memory().available / (1024**3) | |
logger.info(f"Available memory: {available_memory_gb:.1f} GB") | |
# Enable memory-efficient loading for low-memory systems | |
cpu_loading_kwargs = { | |
"low_cpu_mem_usage": True, | |
"torch_dtype": torch_dtype, | |
"device_map": device_map | |
} | |
# Use offloading for very low memory systems (< 8GB available) | |
if available_memory_gb < 8: | |
logger.warning("Low memory detected, enabling aggressive memory optimizations") | |
cpu_loading_kwargs.update({ | |
"offload_folder": "/tmp/model_offload", | |
"offload_state_dict": True | |
}) | |
else: | |
logger.info("Configuring for GPU loading...") | |
torch_dtype = torch.float16 | |
device_map = "auto" | |
cpu_loading_kwargs = { | |
"torch_dtype": torch_dtype, | |
"device_map": device_map, | |
"low_cpu_mem_usage": True | |
} | |
logger.info("Loading tokenizer...") | |
self.tokenizer = AutoTokenizer.from_pretrained( | |
model_name, | |
token=hf_token, | |
use_fast=True # Use fast tokenizer for better performance | |
) | |
logger.info(f"Loading model with dtype={torch_dtype}, device_map={device_map}...") | |
self.model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
token=hf_token, | |
trust_remote_code=False, # Security best practice | |
**cpu_loading_kwargs | |
) | |
# Ensure pad token is set | |
if self.tokenizer.pad_token is None: | |
self.tokenizer.pad_token = self.tokenizer.eos_token | |
# Move model to appropriate device if needed | |
if device == "cpu" and hasattr(self.model, 'to'): | |
self.model = self.model.to('cpu') | |
logger.info("Model moved to CPU") | |
self.model_loaded = True | |
logger.info(f"Model loaded successfully on {device}!") | |
return True | |
except torch.cuda.OutOfMemoryError as e: | |
logger.error(f"GPU out of memory: {str(e)}") | |
logger.info("Attempting CPU fallback due to GPU memory constraints...") | |
try: | |
# Force CPU loading if GPU fails | |
self.model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
token=hf_token, | |
trust_remote_code=False, | |
torch_dtype=torch.float32, | |
device_map="cpu", | |
low_cpu_mem_usage=True | |
) | |
self.model_loaded = True | |
logger.info("Model loaded successfully on CPU after GPU failure!") | |
return True | |
except Exception as fallback_e: | |
logger.error(f"CPU fallback also failed: {str(fallback_e)}") | |
self.model = None | |
self.tokenizer = None | |
self.model_loaded = False | |
return False | |
except ImportError as e: | |
logger.error(f"Missing dependency for model loading: {str(e)}") | |
logger.info("Please ensure all required packages are installed: pip install -r requirements.txt") | |
self.model = None | |
self.tokenizer = None | |
self.model_loaded = False | |
return False | |
except OSError as e: | |
if "disk quota exceeded" in str(e).lower() or "no space left" in str(e).lower(): | |
logger.error("Insufficient disk space for model loading") | |
logger.info("Please free up disk space and try again") | |
elif "connection" in str(e).lower() or "timeout" in str(e).lower(): | |
logger.error("Network connection issue during model download") | |
logger.info("Please check your internet connection and try again") | |
else: | |
logger.error(f"OS error during model loading: {str(e)}") | |
self.model = None | |
self.tokenizer = None | |
self.model_loaded = False | |
return False | |
except Exception as e: | |
logger.error(f"Failed to load model {model_name}: {str(e)}", exc_info=True) | |
logger.warning("Falling back to demo mode due to model loading failure") | |
# Provide helpful troubleshooting info | |
if device == "cpu": | |
logger.info("CPU loading troubleshooting tips:") | |
logger.info("- Ensure sufficient RAM (minimum 8GB recommended)") | |
logger.info("- Check that PyTorch CPU version is installed") | |
logger.info("- Verify HuggingFace token is valid") | |
self.model = None | |
self.tokenizer = None | |
self.model_loaded = False | |
return False | |
def analyze_symptoms(self, symptoms_text, max_length=512, temperature=0.7): | |
"""Analyze symptoms and provide medical insights""" | |
# Try to load model if not already loaded | |
if not self.model_loaded: | |
if not self.load_model(): | |
# Fallback to demo response if model fails to load | |
return self._get_demo_response(symptoms_text) | |
if not self.model or not self.tokenizer: | |
return self._get_demo_response(symptoms_text) | |
# Format the prompt for medical analysis | |
prompt = f"""Patient presents with the following symptoms: {symptoms_text} | |
Based on these symptoms, provide a medical analysis including: | |
1. Possible differential diagnoses | |
2. Recommended next steps | |
3. When to seek immediate medical attention | |
Medical Analysis:""" | |
try: | |
# Tokenize input | |
inputs = self.tokenizer( | |
prompt, | |
return_tensors="pt", | |
truncation=True, | |
max_length=max_length, | |
padding=True | |
) | |
# Generate response | |
with torch.no_grad(): | |
outputs = self.model.generate( | |
inputs.input_ids, | |
attention_mask=inputs.attention_mask, | |
max_new_tokens=400, | |
temperature=temperature, | |
do_sample=True, | |
pad_token_id=self.tokenizer.eos_token_id, | |
eos_token_id=self.tokenizer.eos_token_id | |
) | |
# Decode response | |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Extract only the generated part (after the prompt) | |
generated_text = response[len(prompt):].strip() | |
return generated_text | |
except Exception as e: | |
return f"Error during analysis: {str(e)}" | |
def _get_demo_response(self, symptoms_text): | |
"""Provide a demo response when model is not available""" | |
symptoms_lower = symptoms_text.lower() | |
# Simple keyword-based demo responses | |
if any(word in symptoms_lower for word in ['fever', 'headache', 'fatigue', 'body aches']): | |
return """**DEMO MODE - Model not loaded** | |
Based on the symptoms described (fever, headache, fatigue), here's a general analysis: | |
**Possible Differential Diagnoses:** | |
1. Viral infection (common cold, flu) | |
2. Bacterial infection | |
3. Stress or exhaustion | |
4. Early signs of other conditions | |
**Recommended Next Steps:** | |
1. Rest and adequate hydration | |
2. Monitor temperature regularly | |
3. Over-the-counter pain relievers if appropriate | |
4. Observe for worsening symptoms | |
**When to Seek Immediate Medical Attention:** | |
- High fever (>101.3°F/38.5°C) | |
- Severe headache or neck stiffness | |
- Difficulty breathing | |
- Persistent vomiting | |
- Symptoms worsen rapidly | |
*Note: This is a demo response. For actual medical analysis, the MedGemma model needs to be loaded.*""" | |
elif any(word in symptoms_lower for word in ['chest pain', 'breathing', 'shortness']): | |
return """**DEMO MODE - Model not loaded** | |
Based on chest-related symptoms, here's a general analysis: | |
**Possible Differential Diagnoses:** | |
1. Respiratory infection | |
2. Muscle strain | |
3. Anxiety-related symptoms | |
4. More serious conditions requiring evaluation | |
**Recommended Next Steps:** | |
1. Seek medical evaluation promptly | |
2. Avoid strenuous activity | |
3. Monitor breathing patterns | |
4. Note any associated symptoms | |
**When to Seek Immediate Medical Attention:** | |
- Severe chest pain | |
- Difficulty breathing | |
- Pain spreading to arm, jaw, or back | |
- Dizziness or fainting | |
- These symptoms require immediate medical care | |
*Note: This is a demo response. For actual medical analysis, the MedGemma model needs to be loaded.*""" | |
else: | |
return f"""**DEMO MODE - Model not loaded** | |
Thank you for describing your symptoms. In demo mode, I can provide general guidance: | |
**General Recommendations:** | |
1. Keep track of when symptoms started | |
2. Note any factors that make symptoms better or worse | |
3. Stay hydrated and get adequate rest | |
4. Consider consulting a healthcare provider | |
**When to Seek Medical Attention:** | |
- Symptoms persist or worsen | |
- You develop new concerning symptoms | |
- You have underlying health conditions | |
- You're unsure about the severity | |
For a proper AI-powered analysis of your specific symptoms: "{symptoms_text[:100]}...", the MedGemma model would need to be successfully loaded. | |
*Note: This is a demo response. For actual medical analysis, the MedGemma model needs to be loaded.*""" | |
# Initialize the analyzer | |
analyzer = MedGemmaSymptomAnalyzer() | |
def analyze_symptoms_interface(symptoms, temperature): | |
"""Interface function for Gradio""" | |
if not symptoms.strip(): | |
return "Please enter some symptoms to analyze." | |
# Add medical disclaimer | |
disclaimer = """⚠️ **MEDICAL DISCLAIMER**: This analysis is for educational purposes only and should not replace professional medical advice. Always consult with a healthcare provider for proper diagnosis and treatment. | |
""" | |
analysis = analyzer.analyze_symptoms(symptoms, temperature=temperature) | |
return disclaimer + analysis | |
# Create Gradio interface | |
with gr.Blocks(title="MedGemma Symptom Analyzer", theme=gr.themes.Soft()) as demo: | |
gr.Markdown(""" | |
# 🏥 MedGemma Symptom Analyzer | |
This application uses Google's MedGemma model to provide preliminary analysis of medical symptoms. | |
Enter your symptoms below to get AI-powered insights. | |
**Important**: This tool is for educational purposes only and should not replace professional medical advice. | |
""") | |
with gr.Row(): | |
with gr.Column(scale=2): | |
symptoms_input = gr.Textbox( | |
label="Describe your symptoms", | |
placeholder="Example: I have been experiencing headaches, fever, and fatigue for the past 3 days...", | |
lines=5, | |
max_lines=10 | |
) | |
temperature_slider = gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.7, | |
step=0.1, | |
label="Response Creativity (Temperature)", | |
info="Lower values = more focused, Higher values = more creative" | |
) | |
analyze_btn = gr.Button("🔍 Analyze Symptoms", variant="primary") | |
with gr.Column(scale=3): | |
output = gr.Textbox( | |
label="Medical Analysis", | |
lines=15, | |
max_lines=20, | |
interactive=False | |
) | |
# Example symptoms | |
gr.Markdown("### 📝 Example Symptoms to Try:") | |
with gr.Row(): | |
example1 = gr.Button("Flu-like symptoms", size="sm") | |
example2 = gr.Button("Chest pain", size="sm") | |
example3 = gr.Button("Digestive issues", size="sm") | |
# Event handlers | |
analyze_btn.click( | |
analyze_symptoms_interface, | |
inputs=[symptoms_input, temperature_slider], | |
outputs=output | |
) | |
example1.click( | |
lambda: "I have been experiencing fever, body aches, headache, and fatigue for the past 2 days. I also have a slight cough.", | |
outputs=symptoms_input | |
) | |
example2.click( | |
lambda: "I'm experiencing chest pain that worsens when I take deep breaths. It started this morning and is accompanied by shortness of breath.", | |
outputs=symptoms_input | |
) | |
example3.click( | |
lambda: "I have been having stomach pain, nausea, and diarrhea for the past day. I also feel bloated and have lost my appetite.", | |
outputs=symptoms_input | |
) | |
gr.Markdown(""" | |
### ⚠️ Important Notes: | |
- This is an AI tool for educational purposes only | |
- Always consult healthcare professionals for medical concerns | |
- Seek immediate medical attention for severe or emergency symptoms | |
- The AI may not always provide accurate medical information | |
""") | |
if __name__ == "__main__": | |
demo.launch() |