import gradio as gr
from PIL import Image
from dataclasses import dataclass
import random
from transformers import pipeline
from huggingface_hub import InferenceClient, login
import os

@dataclass(init=True, repr=True, eq=True)
class PatientMetadata:
    """Patient attributes that accompany an imaging scan in the generated report.

    A plain mutable record. Field order below defines the positional
    constructor signature, so it must not be rearranged.
    """

    # Age in whole years.
    age: int
    # Smoking category label, e.g. "Never Smoker" / "Former Smoker" / "Current Smoker".
    smoking_status: str
    # Whether a family history (of breast cancer, per the report wording) is present.
    family_history: bool
    # Menopause label, e.g. "Pre-menopausal" or "Post-menopausal".
    menopause_status: str
    # Whether the patient has had a mammogram before.
    previous_mammogram: bool
    # Density category label, e.g. "C: Heterogeneously dense".
    breast_density: str
    # Whether the patient is currently on hormone therapy.
    hormone_therapy: bool

class SimplifiedBreastAnalyzer:
    """Analyze a breast image and produce a plain-text findings report.

    Steps: one image-classification pipeline decides whether a tumor is
    present; when it is, a second image-classification pipeline assigns a
    size label; finally a Mixtral instruct model, called through
    ``InferenceClient``, writes the narrative report.

    All patient metadata is randomly generated per call by
    ``_generate_synthetic_metadata`` -- it is demo data, not a real record.
    """

    def __init__(self, hf_token: str):
        """Log in to Hugging Face and create the classifiers and report client.

        Args:
            hf_token: Hugging Face access token; passed to ``login`` and to the
                ``InferenceClient`` used for report generation.
        """
        print("Initializing system...")

        # Log in to the Hugging Face Hub with the provided token.
        login(token=hf_token)

        # Vision pipelines for tumor detection and size classification, pinned to CPU.
        self.tumor_classifier = pipeline(
            "image-classification",
            model="SIATCN/vit_tumor_classifier",
            device="cpu",
        )
        self.size_classifier = pipeline(
            "image-classification",
            model="SIATCN/vit_tumor_radius_detection_finetuned",
            device="cpu",
        )

        # Text-generation client for the narrative report.
        self.report_generator = InferenceClient(
            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
            token=hf_token,
        )

        print("Initialization complete!")

    def _generate_synthetic_metadata(self) -> PatientMetadata:
        """Return randomly generated patient metadata for demo purposes.

        Age is drawn uniformly from 40-75 inclusive; menopause status is derived
        from age (over 50 -> post-menopausal); every other field is an
        independent uniform random choice. Uses the module-level ``random``
        generator, so results vary between calls unless it is seeded elsewhere.
        """
        age = random.randint(40, 75)
        smoking_status = random.choice(["Never Smoker", "Former Smoker", "Current Smoker"])
        family_history = random.choice([True, False])
        menopause_status = "Post-menopausal" if age > 50 else "Pre-menopausal"
        previous_mammogram = random.choice([True, False])
        breast_density = random.choice([
            "A: Almost entirely fatty",
            "B: Scattered fibroglandular",
            "C: Heterogeneously dense",
            "D: Extremely dense",
        ])
        hormone_therapy = random.choice([True, False])

        return PatientMetadata(
            age=age,
            smoking_status=smoking_status,
            family_history=family_history,
            menopause_status=menopause_status,
            previous_mammogram=previous_mammogram,
            breast_density=breast_density,
            hormone_therapy=hormone_therapy,
        )

    def _process_image(self, image: Image.Image) -> Image.Image:
        """Convert *image* to RGB and resize it to 224x224 for the classifiers.

        The resize targets a fixed size and does not preserve aspect ratio.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    @staticmethod
    def _format_risk_factors(metadata: PatientMetadata) -> str:
        """Return the patient's risk factors as a comma-separated phrase.

        Smoking status is always included (lower-cased); family history and
        hormone therapy are included only when present.
        """
        factors = []
        if metadata.family_history:
            factors.append('family history of breast cancer')
        factors.append(metadata.smoking_status.lower())
        if metadata.hormone_therapy:
            factors.append('currently on hormone therapy')
        return ', '.join(factors)

    def _generate_medical_report(self, has_tumor: bool, tumor_size: str, metadata: PatientMetadata) -> str:
        """Ask the Mixtral instruct model for a findings/recommendations report.

        Args:
            has_tumor: Whether the tumor classifier flagged an abnormality.
            tumor_size: Size label from the size classifier; only used when
                ``has_tumor`` is true.
            metadata: Patient information to include in the prompt.

        Returns:
            The generated text prefixed with a "FINDINGS AND RECOMMENDATIONS:" header.
        """
        finding = 'Abnormal area detected' if has_tumor else 'No abnormalities detected'
        scan_lines = [f'- Finding: {finding}']
        if has_tumor:
            scan_lines.append(f'- Size of abnormal area: {tumor_size} cm')
        scan_section = '\n'.join(scan_lines)

        # Mixtral instruct format is "<s>[INST] ... [/INST]" followed by the
        # model's answer. "</s>" closes that answer, so it must NOT be appended
        # to the prompt (the original did, which signals end-of-sequence
        # before the model has produced anything).
        prompt = f"""<s>[INST] Generate a detailed medical report for this breast imaging scan:

Scan Results:
{scan_section}

Patient Information:
- Age: {metadata.age} years
- Risk factors: {self._format_risk_factors(metadata)}
- Breast density: {metadata.breast_density}
- Previous mammogram: {'Yes' if metadata.previous_mammogram else 'No'}
- Menopausal status: {metadata.menopause_status}

Please provide:
1. A clear interpretation of the findings
2. A specific recommendation for next steps based on the findings and risk factors
3. Recommended follow-up timeline [/INST]"""

        # Low temperature plus a fixed seed keeps reports fairly stable across runs.
        response = self.report_generator.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            seed=42,
        )

        return f"FINDINGS AND RECOMMENDATIONS:\n{response}"

    def analyze(self, image: Image.Image) -> str:
        """Run the full pipeline on *image* and return a formatted text report.

        Errors are not raised: any exception is converted into an error string
        (including the traceback) so the UI always has something to display.

        Args:
            image: PIL image from the UI, or ``None`` when nothing was uploaded.

        Returns:
            The scan results, patient information and generated report, or an
            error message.
        """
        # Gradio passes None when the user submits without uploading an image.
        if image is None:
            return "Error during analysis: no image was provided. Please upload a breast image."

        try:
            processed_image = self._process_image(image)
            metadata = self._generate_synthetic_metadata()

            # Detect tumor: the top-ranked label decides.
            tumor_result = self.tumor_classifier(processed_image)
            has_tumor = tumor_result[0]['label'] == 'tumor'

            # Only run the size model when a tumor was found; otherwise its
            # output would never be used.
            tumor_size = ''
            if has_tumor:
                size_result = self.size_classifier(processed_image)
                tumor_size = size_result[0]['label'].replace('tumor-', '')

            report = self._generate_medical_report(has_tumor, tumor_size, metadata)

            scan_lines = ['⚠️ Abnormal area detected' if has_tumor else '✓ No abnormalities detected']
            if has_tumor:
                scan_lines.append(f'Size of abnormal area: {tumor_size} cm')
            scan_section = '\n'.join(scan_lines)

            return f"""SCAN RESULTS:
{scan_section}

PATIENT INFORMATION:
• Age: {metadata.age} years
• Risk Factors: {self._format_risk_factors(metadata)}
• Breast Density: {metadata.breast_density}
• Previous Mammogram: {'Yes' if metadata.previous_mammogram else 'No'}
• Menopausal Status: {metadata.menopause_status}

{report}"""
        except Exception as e:
            # UI boundary: report the failure (with traceback) instead of crashing the app.
            import traceback
            return f"Error during analysis: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"

def create_interface(hf_token: str) -> gr.Interface:
    """Build the Gradio UI around a newly constructed ``SimplifiedBreastAnalyzer``.

    Args:
        hf_token: Hugging Face access token, forwarded to the analyzer.

    Returns:
        A configured, not-yet-launched ``gr.Interface``.
    """
    # Constructing the analyzer creates both classifier pipelines and the
    # report client, so this call does the heavy start-up work.
    analyzer = SimplifiedBreastAnalyzer(hf_token)

    # Built without source indentation: Gradio renders ``description`` as
    # Markdown, and lines indented inside a paragraph are treated as
    # continuation text, which collapsed the numbered steps into one line.
    description = (
        "Upload a breast image for comprehensive analysis. The system will:\n\n"
        "1. Detect the presence of tumors\n"
        "2. Classify tumor size if present\n"
        "3. Generate a detailed medical report with recommendations\n\n"
        "*Note: the patient information in the results is randomly generated "
        "for demonstration purposes, and the output is not a medical diagnosis.*"
    )

    return gr.Interface(
        fn=analyzer.analyze,
        inputs=[gr.Image(type="pil", label="Upload Breast Image for Analysis")],
        outputs=[gr.Textbox(label="Analysis Results", lines=20)],
        title="Breast Imaging Analysis System",
        description=description,
    )

if __name__ == "__main__":
    print("Starting application...")

    # The token is supplied as a secret exposed through the environment.
    token = os.environ.get("HUGGINGFACE_TOKEN")
    if not token:
        raise ValueError("Please set HUGGINGFACE_TOKEN environment variable")

    app = create_interface(token)

    # Launch settings for Spaces: listen on all interfaces at the standard
    # Spaces port; no public share link is needed there.
    launch_options = dict(
        debug=True,
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
    app.launch(**launch_options)