from fastapi import FastAPI, File, UploadFile, HTTPException
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from PIL import Image
import io
import os
import logging
from huggingface_hub import login

# Enable logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

MODEL_NAME = "mervinpraison/Llama-3.2-11B-Vision-Radiology-mini"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Redirect the Hugging Face cache to a writable location (both variables point
# at the same directory; TRANSFORMERS_CACHE is read by older transformers releases)
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"

# Ensure Hugging Face Token is Set
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    logger.error("Hugging Face token not found! Set HUGGINGFACE_TOKEN in the environment.")
    raise RuntimeError("Hugging Face token missing. Set it in your environment.")

# Login to Hugging Face
try:
    login(HF_TOKEN)
except Exception as e:
    logger.error(f"Failed to authenticate Hugging Face token: {e}")
    raise RuntimeError("Authentication with Hugging Face failed.")

# Configure Quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Change to load_in_8bit=True if 4-bit fails
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
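# Rough footprint: 11B parameters at 4 bits is about 5.5 GB of weights
# (11e9 parameters × 0.5 bytes), before activations and the KV cache.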

# Load Model
try:
    logger.info("Loading model and processor...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME, cache_dir="/tmp/huggingface")
    # device_map="auto" lets accelerate place the quantized weights; calling
    # .to(DEVICE) on a bitsandbytes-quantized model raises an error.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization_config,
        cache_dir="/tmp/huggingface",
        device_map="auto",
    )

    torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 matmuls on Ampere+ GPUs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise RuntimeError("Model loading failed. Check model accessibility.")

# Allowed Formats
ALLOWED_FORMATS = {"jpeg", "jpg", "png", "bmp", "tiff"}

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    try:
        # Guard against a missing filename before checking the extension
        ext = (file.filename or "").rsplit(".", 1)[-1].lower()
        if ext not in ALLOWED_FORMATS:
            raise HTTPException(status_code=400, detail=f"Invalid file format: {ext}. Upload an image file.")

        # Read Image; a corrupt or non-image payload should be a 400, not a 500
        image_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception:
            raise HTTPException(status_code=400, detail="File could not be decoded as an image.")

        # Validation Step. Llama 3.2 Vision's processor expects one <|image|>
        # placeholder per image, so the prompt goes through the chat template
        # instead of being passed as bare text.
        validation_messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Is this a medical X-ray or CT scan? Answer only 'yes' or 'no'."},
                ],
            }
        ]
        validation_text = processor.apply_chat_template(validation_messages, add_generation_prompt=True)
        # add_special_tokens=False: the chat template already inserts <|begin_of_text|>
        validation_inputs = processor(
            text=validation_text, images=image, add_special_tokens=False, return_tensors="pt"
        ).to(DEVICE)

        with torch.no_grad():
            # Greedy decoding: temperature/top_p/top_k were inert without do_sample=True
            validation_output = model.generate(**validation_inputs, max_new_tokens=10)

        # Decode only the new tokens; the full sequence includes the prompt,
        # whose literal "'yes' or 'no'" would make the check below always pass.
        prompt_len = validation_inputs["input_ids"].shape[1]
        validation_result = processor.batch_decode(
            validation_output[:, prompt_len:], skip_special_tokens=True
        )[0].strip().lower()
        logger.info(f"Validation result: {validation_result}")

        if "yes" not in validation_result:
            raise HTTPException(status_code=400, detail="Uploaded image is not an X-ray or CT scan.")

        # Analysis Step: same chat-template pattern, with sampling enabled
        analysis_messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": (
                            "Analyze this X-ray image and provide a detailed medical report:\n\n"
                            "Type of X-ray:\n"
                            "Key Findings:\n"
                            "• [Findings]\n"
                            "Potential Conditions:\n"
                            "• [Possible Diagnoses]\n"
                            "Recommendations:\n"
                            "• [Follow-up Actions]"
                        ),
                    },
                ],
            }
        ]
        analysis_text = processor.apply_chat_template(analysis_messages, add_generation_prompt=True)
        analysis_inputs = processor(
            text=analysis_text, images=image, add_special_tokens=False, return_tensors="pt"
        ).to(DEVICE)

        with torch.no_grad():
            analysis_output = model.generate(
                **analysis_inputs,
                max_new_tokens=512,
                do_sample=True,  # required for the sampling knobs below to take effect
                temperature=0.7,
                top_p=0.7,
                top_k=50,
            )

        # Again, decode only the generated continuation, not the prompt
        prompt_len = analysis_inputs["input_ids"].shape[1]
        analysis_content = processor.batch_decode(analysis_output[:, prompt_len:], skip_special_tokens=True)[0]
        cleaned_analysis = (
            analysis_content.replace("**", "").replace("*", "•").replace("_", "").strip()
        )

        return {"analysis": cleaned_analysis}

    except HTTPException as http_err:
        logger.error(f"Request rejected: {http_err.detail}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail="Internal error while processing the image.")
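
# Minimal local entry point: a sketch assuming uvicorn is installed alongside
# fastapi (port 8000 is an arbitrary choice, not taken from any deployment
# config). Example request once the server is up, with any local JPEG/PNG:
#   curl -X POST http://localhost:8000/predict/ -F "file=@chest_xray.jpg"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)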