# Hugging Face Space app by gnumanth — commit a5da8f2 ("chore: device optimization")
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import re
import logging
import os
# Configure logging: INFO-level, timestamped records mirrored to the
# console and to a persistent file next to the app.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('medgemma_app.log'),
    ],
)
logger = logging.getLogger(__name__)
class MedGemmaSymptomAnalyzer:
    """Symptom analyzer backed by Google's MedGemma model.

    The model is loaded lazily on first use.  Whenever loading fails
    (missing HF token, insufficient memory, no network, missing
    dependency) the analyzer degrades to canned "demo mode" responses
    instead of raising, so the Gradio app never crashes at startup.
    """

    # Hugging Face model id loaded by load_model().
    MODEL_NAME = "google/medgemma-4b-it"

    def __init__(self):
        # Both stay None until load_model() succeeds.
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        logger.info("Initializing MedGemma Symptom Analyzer...")

    def _reset_after_failure(self):
        """Drop any partially-initialized model state after a failed load."""
        self.model = None
        self.tokenizer = None
        self.model_loaded = False

    def _build_loading_kwargs(self, device):
        """Return ``from_pretrained`` kwargs tuned for *device* ("cpu"/"cuda")."""
        if device != "cpu":
            logger.info("Configuring for GPU loading...")
            return {
                "torch_dtype": torch.float16,
                "device_map": "auto",
                "low_cpu_mem_usage": True,
            }
        logger.info("Configuring for CPU-optimized loading...")
        # Use half the available threads so the host stays responsive
        # during inference.
        torch.set_num_threads(max(1, torch.get_num_threads() // 2))
        kwargs = {
            "low_cpu_mem_usage": True,
            "torch_dtype": torch.float32,
            "device_map": "cpu",
        }
        try:
            # psutil is optional: it only powers the low-memory heuristic.
            # The original code aborted the whole load when it was missing.
            import psutil
            available_memory_gb = psutil.virtual_memory().available / (1024 ** 3)
            logger.info(f"Available memory: {available_memory_gb:.1f} GB")
            # Offload weights to disk on very low-memory systems (< 8 GB free).
            if available_memory_gb < 8:
                logger.warning("Low memory detected, enabling aggressive memory optimizations")
                kwargs.update({
                    "offload_folder": "/tmp/model_offload",
                    "offload_state_dict": True,
                })
        except ImportError:
            logger.warning("psutil not available; skipping low-memory detection")
        return kwargs

    def load_model(self):
        """Load MedGemma with device-appropriate optimizations.

        Returns:
            bool: True when model and tokenizer are ready; False when
            loading failed (callers then fall back to demo mode).
        """
        if self.model_loaded:
            return True
        model_name = self.MODEL_NAME
        logger.info(f"Loading model: {model_name}")
        # Detect available device and log system info.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Device detected: {device}")
        if device == "cpu":
            logger.info(f"CPU threads available: {torch.get_num_threads()}")
        else:
            logger.info(f"CUDA device: {torch.cuda.get_device_name()}")
        # Token comes from the environment (Hugging Face Spaces secrets).
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            logger.info("Using HF_TOKEN for authentication")
        else:
            logger.warning("HF_TOKEN not found in environment variables")
        try:
            loading_kwargs = self._build_loading_kwargs(device)
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=hf_token,
                use_fast=True  # fast tokenizer for better performance
            )
            logger.info(
                f"Loading model with dtype={loading_kwargs['torch_dtype']}, "
                f"device_map={loading_kwargs['device_map']}..."
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=hf_token,
                trust_remote_code=False,  # security best practice
                **loading_kwargs
            )
            # Padding in analyze_symptoms needs a pad token; reuse EOS if unset.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # device_map="cpu" should already place the model, but be explicit.
            if device == "cpu" and hasattr(self.model, 'to'):
                self.model = self.model.to('cpu')
                logger.info("Model moved to CPU")
            self.model_loaded = True
            logger.info(f"Model loaded successfully on {device}!")
            return True
        except torch.cuda.OutOfMemoryError as e:
            logger.error(f"GPU out of memory: {str(e)}")
            logger.info("Attempting CPU fallback due to GPU memory constraints...")
            try:
                # Force CPU loading if the GPU ran out of memory.
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    token=hf_token,
                    trust_remote_code=False,
                    torch_dtype=torch.float32,
                    device_map="cpu",
                    low_cpu_mem_usage=True
                )
                # BUGFIX: the original fallback skipped the pad-token setup.
                if self.tokenizer is not None and self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model_loaded = True
                logger.info("Model loaded successfully on CPU after GPU failure!")
                return True
            except Exception as fallback_e:
                logger.error(f"CPU fallback also failed: {str(fallback_e)}")
                self._reset_after_failure()
                return False
        except ImportError as e:
            logger.error(f"Missing dependency for model loading: {str(e)}")
            logger.info("Please ensure all required packages are installed: pip install -r requirements.txt")
            self._reset_after_failure()
            return False
        except OSError as e:
            # from_pretrained surfaces disk/network problems as OSError;
            # inspect the message to give a more actionable hint.
            message = str(e).lower()
            if "disk quota exceeded" in message or "no space left" in message:
                logger.error("Insufficient disk space for model loading")
                logger.info("Please free up disk space and try again")
            elif "connection" in message or "timeout" in message:
                logger.error("Network connection issue during model download")
                logger.info("Please check your internet connection and try again")
            else:
                logger.error(f"OS error during model loading: {str(e)}")
            self._reset_after_failure()
            return False
        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {str(e)}", exc_info=True)
            logger.warning("Falling back to demo mode due to model loading failure")
            # Provide helpful troubleshooting info.
            if device == "cpu":
                logger.info("CPU loading troubleshooting tips:")
                logger.info("- Ensure sufficient RAM (minimum 8GB recommended)")
                logger.info("- Check that PyTorch CPU version is installed")
                logger.info("- Verify HuggingFace token is valid")
            self._reset_after_failure()
            return False

    def analyze_symptoms(self, symptoms_text, max_length=512, temperature=0.7):
        """Analyze symptoms and provide medical insights.

        Args:
            symptoms_text: free-text symptom description from the user.
            max_length: truncation limit (tokens) for the prompt.
            temperature: sampling temperature forwarded to ``generate``.

        Returns:
            The generated analysis, a demo response when the model is
            unavailable, or an error string when generation fails.
        """
        # Try to load the model lazily; fall back to demo mode on failure.
        if not self.model_loaded:
            if not self.load_model():
                return self._get_demo_response(symptoms_text)
        if not self.model or not self.tokenizer:
            return self._get_demo_response(symptoms_text)
        # Format the prompt for medical analysis.
        prompt = f"""Patient presents with the following symptoms: {symptoms_text}
Based on these symptoms, provide a medical analysis including:
1. Possible differential diagnoses
2. Recommended next steps
3. When to seek immediate medical attention
Medical Analysis:"""
        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding=True
            )
            # BUGFIX: move inputs onto the model's device — the original
            # left them on CPU, which fails when the model was placed on a
            # GPU via device_map="auto".
            model_device = next(self.model.parameters()).device
            input_ids = inputs.input_ids.to(model_device)
            attention_mask = inputs.attention_mask.to(model_device)
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=400,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            # Decode and return only the generated part (after the prompt).
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response[len(prompt):].strip()
        except Exception as e:
            return f"Error during analysis: {str(e)}"

    def _get_demo_response(self, symptoms_text):
        """Provide a keyword-matched canned response when the model is unavailable."""
        symptoms_lower = symptoms_text.lower()
        # Simple keyword-based demo responses.
        if any(word in symptoms_lower for word in ['fever', 'headache', 'fatigue', 'body aches']):
            return """**DEMO MODE - Model not loaded**
Based on the symptoms described (fever, headache, fatigue), here's a general analysis:
**Possible Differential Diagnoses:**
1. Viral infection (common cold, flu)
2. Bacterial infection
3. Stress or exhaustion
4. Early signs of other conditions
**Recommended Next Steps:**
1. Rest and adequate hydration
2. Monitor temperature regularly
3. Over-the-counter pain relievers if appropriate
4. Observe for worsening symptoms
**When to Seek Immediate Medical Attention:**
- High fever (>101.3°F/38.5°C)
- Severe headache or neck stiffness
- Difficulty breathing
- Persistent vomiting
- Symptoms worsen rapidly
*Note: This is a demo response. For actual medical analysis, the MedGemma model needs to be loaded.*"""
        elif any(word in symptoms_lower for word in ['chest pain', 'breathing', 'shortness']):
            return """**DEMO MODE - Model not loaded**
Based on chest-related symptoms, here's a general analysis:
**Possible Differential Diagnoses:**
1. Respiratory infection
2. Muscle strain
3. Anxiety-related symptoms
4. More serious conditions requiring evaluation
**Recommended Next Steps:**
1. Seek medical evaluation promptly
2. Avoid strenuous activity
3. Monitor breathing patterns
4. Note any associated symptoms
**When to Seek Immediate Medical Attention:**
- Severe chest pain
- Difficulty breathing
- Pain spreading to arm, jaw, or back
- Dizziness or fainting
- These symptoms require immediate medical care
*Note: This is a demo response. For actual medical analysis, the MedGemma model needs to be loaded.*"""
        else:
            return f"""**DEMO MODE - Model not loaded**
Thank you for describing your symptoms. In demo mode, I can provide general guidance:
**General Recommendations:**
1. Keep track of when symptoms started
2. Note any factors that make symptoms better or worse
3. Stay hydrated and get adequate rest
4. Consider consulting a healthcare provider
**When to Seek Medical Attention:**
- Symptoms persist or worsen
- You develop new concerning symptoms
- You have underlying health conditions
- You're unsure about the severity
For a proper AI-powered analysis of your specific symptoms: "{symptoms_text[:100]}...", the MedGemma model would need to be successfully loaded.
*Note: This is a demo response. For actual medical analysis, the MedGemma model needs to be loaded.*"""
# Module-level singleton shared by all Gradio requests; the model itself
# is loaded lazily on the first analysis call, not here.
analyzer = MedGemmaSymptomAnalyzer()
def analyze_symptoms_interface(symptoms, temperature):
    """Gradio handler: run the analyzer on the user's symptom text.

    Returns a prompt to enter symptoms when the input is blank;
    otherwise the analysis prefixed with a medical disclaimer.
    """
    if not symptoms.strip():
        return "Please enter some symptoms to analyze."
    # Disclaimer shown above every analysis.
    header = (
        "⚠️ **MEDICAL DISCLAIMER**: This analysis is for educational purposes "
        "only and should not replace professional medical advice. Always "
        "consult with a healthcare provider for proper diagnosis and treatment.\n"
    )
    return header + analyzer.analyze_symptoms(symptoms, temperature=temperature)
# Create Gradio interface: two-column layout — input controls on the
# left, the analysis output on the right — plus example buttons below.
with gr.Blocks(title="MedGemma Symptom Analyzer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🏥 MedGemma Symptom Analyzer
This application uses Google's MedGemma model to provide preliminary analysis of medical symptoms.
Enter your symptoms below to get AI-powered insights.
**Important**: This tool is for educational purposes only and should not replace professional medical advice.
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Free-text symptom entry box.
            symptoms_input = gr.Textbox(
                label="Describe your symptoms",
                placeholder="Example: I have been experiencing headaches, fever, and fatigue for the past 3 days...",
                lines=5,
                max_lines=10
            )
            # Sampling temperature forwarded to analyze_symptoms().
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Response Creativity (Temperature)",
                info="Lower values = more focused, Higher values = more creative"
            )
            analyze_btn = gr.Button("🔍 Analyze Symptoms", variant="primary")
        with gr.Column(scale=3):
            # Read-only output area for the disclaimer-prefixed analysis.
            output = gr.Textbox(
                label="Medical Analysis",
                lines=15,
                max_lines=20,
                interactive=False
            )
    # Example symptoms
    gr.Markdown("### 📝 Example Symptoms to Try:")
    with gr.Row():
        example1 = gr.Button("Flu-like symptoms", size="sm")
        example2 = gr.Button("Chest pain", size="sm")
        example3 = gr.Button("Digestive issues", size="sm")
    # Event handlers
    analyze_btn.click(
        analyze_symptoms_interface,
        inputs=[symptoms_input, temperature_slider],
        outputs=output
    )
    # Each example button fills the input box with a canned description.
    example1.click(
        lambda: "I have been experiencing fever, body aches, headache, and fatigue for the past 2 days. I also have a slight cough.",
        outputs=symptoms_input
    )
    example2.click(
        lambda: "I'm experiencing chest pain that worsens when I take deep breaths. It started this morning and is accompanied by shortness of breath.",
        outputs=symptoms_input
    )
    example3.click(
        lambda: "I have been having stomach pain, nausea, and diarrhea for the past day. I also feel bloated and have lost my appetite.",
        outputs=symptoms_input
    )
    gr.Markdown("""
### ⚠️ Important Notes:
- This is an AI tool for educational purposes only
- Always consult healthcare professionals for medical concerns
- Seek immediate medical attention for severe or emergency symptoms
- The AI may not always provide accurate medical information
""")
if __name__ == "__main__":
    # Launch the Gradio app (blocking) when run as a script.
    demo.launch()