File size: 4,538 Bytes
8229bc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68249e6
8229bc6
a5f98f0
 
 
 
8229bc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import io
import logging

import numpy as np
import onnxruntime
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
# Root logging configuration: INFO and above, one timestamped line per record.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Service-wide logger used by the model loader and the request handlers.
logger = logging.getLogger(name="dr-api")
# The ASGI application; the keyword metadata describes the service in the
# automatically generated API documentation.
app = FastAPI(
    version="1.0.0",
    title="Diabetic Retinopathy Detection API",
    description="API for detecting diabetic retinopathy from retinal images",
)
# Allow the deployed front-ends to call this API from the browser.
# CORSMiddleware compares the request's Origin header against these strings
# exactly, so each entry must be a bare "scheme://host" URL: no markdown
# brackets, no path and no trailing slash, or it will never match.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://diabetes-detection-zeta.vercel.app",
        "https://diabetes-detection-harishvijayasarangank-gmailcoms-projects.vercel.app",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Human-readable DR grade for each class index of the model output
# (predict pairs labels[i] with probabilities[i]).
labels = dict(enumerate([
    "No DR",
    "Mild",
    "Moderate",
    "Severe",
    "Proliferative DR",
]))
# Load the ONNX model once, at import time. The path is relative to the
# process's working directory. On failure the error is logged and `session`
# is set to None, so the app still starts; /health and /predict both check
# for None and report the model as unavailable.
try:
    logger.info("Loading ONNX model...")
    session = onnxruntime.InferenceSession('model.onnx')
    logger.info("Model loaded successfully")
except Exception as e:
    # logger.exception logs at ERROR level *and* appends the traceback, which
    # is what distinguishes e.g. a missing file from an unsupported model.
    logger.exception(f"Error loading model: {e}")
    session = None
@app.get("/health")
async def health_check():
    """Report whether the model is loaded and predictions can be served.

    Returns:
        ``{"status": "healthy", "model_loaded": True}`` when the ONNX session
        loaded, otherwise ``{"status": "unhealthy", "model_loaded": False,
        "message": "Model failed to load"}``. Both are sent with HTTP 200.
    """
    if session is None:
        # `model_loaded` is present in both responses so clients can rely on it.
        # NOTE(review): probes that only look at the status code will see 200
        # here; consider returning 503 if an orchestrator consumes this endpoint.
        return {"status": "unhealthy", "model_loaded": False, "message": "Model failed to load"}
    return {"status": "healthy", "model_loaded": True}
def transform_image(image):
    """Turn an image into the model's input tensor.

    Pipeline:
    1. Resize to 224x224.
    2. Scale pixel values by 1/255.
    3. Normalise each channel with the mean/std below.
    4. Reorder HWC -> CHW and add a leading batch axis of size one.

    Args:
        image: object with a ``resize`` method; ``predict`` passes a PIL image
            already converted to RGB. Three channels are needed for the
            three-element mean/std to broadcast.

    Returns:
        float32 numpy array shaped (1, channels, height, width).
    """
    # Per-channel normalisation constants (presumably training-set statistics
    # of the model; confirm against the training pipeline).
    channel_mean = np.array([0.5353, 0.3628, 0.2486], dtype=np.float32)
    channel_std = np.array([0.2126, 0.1586, 0.1401], dtype=np.float32)

    pixels = np.array(image.resize((224, 224)), dtype=np.float32) / 255.0
    normalised = (pixels - channel_mean) / channel_std
    chw = normalised.transpose(2, 0, 1)
    return chw[np.newaxis, ...].astype(np.float32)
def _infer_probabilities(image_data):
    """Decode an uploaded image, run the ONNX model and return class probabilities.

    Synchronous and CPU-bound; ``predict`` runs it in a worker thread.

    Args:
        image_data: raw bytes of the uploaded file.

    Returns:
        numpy array of softmax probabilities, one per model output class,
        indexed like ``labels``.

    Raises:
        PIL.UnidentifiedImageError: if ``image_data`` cannot be decoded as an image.
    """
    input_img = Image.open(io.BytesIO(image_data)).convert("RGB")
    input_tensor = transform_image(input_img)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    logger.info("Running inference")
    # The first [0] selects the single requested output, and the second [0]
    # selects the single item in the batch built by transform_image. This
    # assumes the model emits shape (1, num_classes); confirm against the
    # exported model.
    scores = session.run([output_name], {input_name: input_tensor})[0][0]
    # Softmax. Subtracting the max first keeps np.exp from overflowing.
    exp_scores = np.exp(scores - np.max(scores))
    return exp_scores / exp_scores.sum()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict diabetic retinopathy from retinal image

    - **file**: Upload a retinal image file

    Returns the confidence for every DR grade, as a whole-number percentage
    (e.g. 46.0), and the name of the grade with the highest probability.
    """
    logger.info(f"Received image: {file.filename}, content-type: {file.content_type}")
    if session is None:
        raise HTTPException(status_code=503, detail="Model not available")
    # content_type is None when the client sends no Content-Type for the part.
    # Treat that as "not an image" instead of crashing with AttributeError (500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image")

    try:
        image_data = await file.read()
        # Decoding, preprocessing and session.run are blocking CPU work. Run
        # them in the thread pool so this coroutine does not stall the event
        # loop, which would also block /health and concurrent requests.
        probabilities = await run_in_threadpool(_infer_probabilities, image_data)

        # Formatting with ".0f" and parsing back rounds each confidence to a
        # whole percent while keeping the float type the API has returned.
        full_confidences = {labels[i]: float(f"{probabilities[i] * 100:.0f}") for i in labels}

        # Choose the winner from the unrounded probabilities. After rounding,
        # two grades can tie (e.g. 45.4% and 45.3% both become 45.0), and dict
        # order would then decide instead of the model.
        highest_class = labels[max(labels, key=lambda i: probabilities[i])]
        logger.info(f"Prediction complete: highest probability class = {highest_class}")

        return {
            "detailed_classification": full_confidences,
            "highest_probability_class": highest_class,
        }

    except Image.UnidentifiedImageError as e:
        # The upload claimed to be an image but Pillow cannot decode it.
        # That is a client error, not a server fault.
        logger.warning(f"Could not decode uploaded image: {e}")
        raise HTTPException(status_code=400, detail="Could not decode image file") from e
    except Exception as e:
        logger.error(f"Error processing image: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

# Development entry point: start uvicorn when this file is executed directly.
if __name__ == "__main__":
    # Serve on all interfaces, port 7860, with auto-reload. uvicorn needs the
    # "module:attribute" import string (not the app object) when reload=True.
    # NOTE(review): "main:app" assumes this file is named main.py; confirm.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )