# DR-server / main.py
# harishvijayasarangan's picture
# Update main.py
# 68249e6 verified
import io
import logging
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import onnxruntime
import numpy as np
from PIL import Image
import uvicorn
# Root logging configuration: INFO level, with each record rendered as
# "<timestamp> - <logger name> - <level> - <message>".
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-wide logger used by the model loader and the request handlers.
logger = logging.getLogger("dr-api")
# ASGI application object; the keyword arguments are the API metadata
# (title, description, version) handed to FastAPI.
app = FastAPI(
    version="1.0.0",
    title="Diabetic Retinopathy Detection API",
    description="API for detecting diabetic retinopathy from retinal images",
)
# Allow the deployed front-ends to call this API from the browser.
# Each entry must be a bare origin (scheme + host): the middleware compares
# it against the request's Origin header, so Markdown-style "[url](url)"
# text would never match and every cross-origin request would be refused.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://diabetes-detection-zeta.vercel.app",
        "https://diabetes-detection-harishvijayasarangank-gmailcoms-projects.vercel.app",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Diabetic retinopathy grade names keyed by class index 0..4; the predict
# handler uses the same indices to read the model's probability vector.
labels = dict(
    enumerate(("No DR", "Mild", "Moderate", "Severe", "Proliferative DR"))
)
# Load the ONNX model once at import time. A failure must not prevent the
# app from starting: `session` is left as None so /health can report the
# problem and /predict can answer 503 instead of crashing.
try:
    logger.info("Loading ONNX model...")
    session = onnxruntime.InferenceSession("model.onnx")
except Exception as e:
    # logger.exception records the traceback, which is needed to tell a
    # missing file from a corrupt model or a runtime/provider problem.
    logger.exception("Error loading model: %s", e)
    session = None
else:
    logger.info("Model loaded successfully")
@app.get("/health")
async def health_check():
    """Report whether the ONNX inference session was created at startup.

    Returns a JSON object whose ``status`` is ``"healthy"`` when the
    module-level ``session`` exists and ``"unhealthy"`` (with an explanatory
    ``message``) when model loading failed.
    """
    if session is not None:
        return {"status": "healthy", "model_loaded": True}
    return {"status": "unhealthy", "message": "Model failed to load"}
def transform_image(image):
    """Convert an image into the normalized float32 batch fed to the model.

    The image is resized to 224x224, divided by 255, standardized with fixed
    per-channel mean/std values, reordered from height-width-channel to
    channel-height-width, and given a leading batch axis of size 1.

    Args:
        image: Object exposing ``resize((w, h))`` whose result converts to an
            H x W x 3 array (the caller passes an RGB PIL image).
            NOTE(review): the /255 scaling presumes 8-bit channel values.

    Returns:
        A float32 numpy array; for an RGB input its shape is (1, 3, 224, 224).
    """
    channel_mean = np.array([0.5353, 0.3628, 0.2486], dtype=np.float32)
    channel_std = np.array([0.2126, 0.1586, 0.1401], dtype=np.float32)

    resized = image.resize((224, 224))
    pixels = np.array(resized, dtype=np.float32) / 255.0

    # The channel statistics broadcast over the trailing (channel) axis.
    standardized = (pixels - channel_mean) / channel_std

    # HWC -> CHW, then prepend the batch dimension.
    channels_first = standardized.transpose(2, 0, 1)
    return channels_first[np.newaxis, ...].astype(np.float32)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict the diabetic retinopathy grade from a retinal image.

    - **file**: Upload a retinal image file

    Returns a JSON object with:

    - ``detailed_classification``: confidence for every DR grade, as a
      whole-number percentage (softmax of the model output, rounded).
    - ``highest_probability_class``: label of the most probable grade,
      chosen from the unrounded probabilities.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if the upload is
            not declared as an image or cannot be decoded, 500 if
            preprocessing or inference fails.
    """
    logger.info(
        "Received image: %s, content-type: %s", file.filename, file.content_type
    )

    if session is None:
        raise HTTPException(status_code=503, detail="Model not available")

    # content_type is None when the client omits the part's Content-Type
    # header; treat that as "not an image" instead of raising AttributeError.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image")

    try:
        image_data = await file.read()

        try:
            input_img = Image.open(io.BytesIO(image_data)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Empty, corrupt, truncated or oversized uploads are client
            # errors, not server faults.
            logger.warning("Rejected undecodable upload: %s", e)
            raise HTTPException(
                status_code=400, detail="File provided is not a valid image"
            ) from e

        input_tensor = transform_image(input_img)
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        logger.info("Running inference")
        # First output, first (only) batch element: one raw score per class.
        prediction = session.run([output_name], {input_name: input_tensor})[0][0]

        # Softmax; subtracting the max keeps exp() from overflowing.
        exp_preds = np.exp(prediction - np.max(prediction))
        probabilities = exp_preds / exp_preds.sum()

        # Whole-number percentages for display.
        full_confidences = {
            label: float(round(float(probabilities[idx] * 100)))
            for idx, label in labels.items()
        }

        # Pick the winner from the raw probabilities: after rounding, two
        # grades can tie (e.g. 33.4 % and 33.3 % both become 33) and the
        # wrong one could be reported.
        best_idx = max(labels, key=lambda idx: probabilities[idx])
        highest_class = labels[best_idx]
        logger.info(
            "Prediction complete: highest probability class = %s", highest_class
        )

        return {
            "detailed_classification": full_confidences,
            "highest_probability_class": highest_class,
        }
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) unchanged.
        raise
    except Exception as e:
        logger.error("Error processing image: %s", e, exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Error processing image: {str(e)}"
        ) from e
# Script entry point: start uvicorn serving the `app` object of this module
# on all interfaces, port 7860.
# NOTE(review): reload=True turns on uvicorn's auto-reloader, normally a
# development setting — confirm it is wanted for deployment.
if __name__ == "__main__":
    server_options = {"host": "0.0.0.0", "port": 7860, "reload": True}
    uvicorn.run("main:app", **server_options)