# Hugging Face Space (status: Running) — file size: 4,538 bytes.
# Commits: 8229bc6 68249e6 8229bc6 a5f98f0 8229bc6
import asyncio
import io
import logging

import numpy as np
import onnxruntime
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
# Root logging: INFO and above, each line stamped with time, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Named logger shared by every part of this service.
logger = logging.getLogger("dr-api")
# ASGI application; title/description/version populate the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Diabetic Retinopathy Detection API",
    description="API for detecting diabetic retinopathy from retinal images",
)
# Browser origins allowed to call this API. Each entry must be a bare
# "scheme://host" string: CORSMiddleware compares it verbatim with the request's
# Origin header, so any extra text (e.g. markdown link syntax) never matches.
ALLOWED_ORIGINS = [
    "https://diabetes-detection-zeta.vercel.app",
    "https://diabetes-detection-harishvijayasarangank-gmailcoms-projects.vercel.app",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Class index -> DR grade name. The integer keys are used to index the model's
# output vector, so the list order below must match the model's class order.
labels = dict(enumerate([
    "No DR",
    "Mild",
    "Moderate",
    "Severe",
    "Proliferative DR",
]))
# Load the ONNX model once at import time. A failure is deliberately not fatal:
# `session` is left as None so /health can report it and /predict answers 503.
try:
    logger.info("Loading ONNX model...")
    # Relative path: resolved against the process's working directory.
    session = onnxruntime.InferenceSession('model.onnx')
    logger.info("Model loaded successfully")
except Exception as e:
    # logger.exception keeps the traceback, which is what makes a load failure
    # (missing file, incompatible opset, bad build) diagnosable from the logs.
    logger.exception("Error loading model: %s", e)
    session = None
@app.get("/health")
async def health_check():
    """Report whether the ONNX model was loaded at startup.

    Always answers HTTP 200; the ``status`` field carries the verdict. Both
    payloads include ``model_loaded`` so clients can read a single key
    regardless of state.
    """
    if session is None:
        return {
            "status": "unhealthy",
            "message": "Model failed to load",
            "model_loaded": False,
        }
    return {"status": "healthy", "model_loaded": True}
def transform_image(image):
    """Convert an image into a normalized float32 NCHW batch of one.

    Steps: resize to 224x224, scale pixel values to [0, 1], normalize each
    channel with the fixed mean/std below, reorder HWC -> CHW and prepend a
    batch axis. Assumes a 3-channel (RGB) image; the caller is expected to
    convert beforehand.

    Args:
        image: object with a PIL-style ``resize((w, h))`` method whose result
            ``np.array`` can turn into an (H, W, 3) array.

    Returns:
        float32 array of shape (1, 3, 224, 224).
    """
    resized = image.resize((224, 224))
    pixels = np.array(resized, dtype=np.float32) / 255.0
    # Per-channel statistics; presumably those of the training data — confirm
    # they match whatever the model was trained with.
    channel_mean = np.array([0.5353, 0.3628, 0.2486], dtype=np.float32)
    channel_std = np.array([0.2126, 0.1586, 0.1401], dtype=np.float32)
    normalized = (pixels - channel_mean) / channel_std
    chw = normalized.transpose(2, 0, 1)
    return chw[np.newaxis, ...].astype(np.float32)
def _predict_probabilities(image_data):
    """Decode, preprocess and classify one uploaded image.

    Synchronous and CPU-bound; ``predict`` runs it in a worker thread so the
    event loop (and other endpoints such as /health) stays responsive.

    Args:
        image_data: raw bytes of the uploaded file.

    Returns:
        1-D numpy array of softmax probabilities over the model's output classes.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image.
    """
    try:
        # PIL opens lazily; convert() forces the decode, so truncated or
        # corrupt files fail inside this try as well.
        input_img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # An unreadable upload is a client error, not a server fault.
        raise HTTPException(status_code=400, detail=f"Error processing image: {e}") from e

    input_tensor = transform_image(input_img)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    logger.info("Running inference")
    # First [0]: the single requested output; second [0]: the only batch row.
    logits = session.run([output_name], {input_name: input_tensor})[0][0]
    # NOTE(review): assumes the exported model emits raw logits; if its graph
    # already ends in a softmax this re-normalizes probabilities — confirm.
    exp_preds = np.exp(logits - np.max(logits))  # shift by max for numerical stability
    return exp_preds / exp_preds.sum()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict diabetic retinopathy from retinal image
    - **file**: Upload a retinal image file
    Returns the confidence (whole percent) for every DR grade and the grade with the highest probability.
    Errors: 503 if the model is not loaded, 400 if the upload is not a readable image, 500 otherwise.
    """
    logger.info("Received image: %s, content-type: %s", file.filename, file.content_type)
    if session is None:
        raise HTTPException(status_code=503, detail="Model not available")
    # content_type can be None when the multipart part has no Content-Type header.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image")
    try:
        image_data = await file.read()
        loop = asyncio.get_running_loop()
        # Decode + inference are CPU-bound; run them off the event loop.
        probabilities = await loop.run_in_executor(None, _predict_probabilities, image_data)
        # Whole-percent confidences (np.rint rounds half-to-even).
        full_confidences = {
            name: float(np.rint(probabilities[idx] * 100)) for idx, name in labels.items()
        }
        # Pick the winner from the unrounded probabilities: after rounding to
        # whole percentages two grades can tie, and a max over the rounded
        # values would report whichever grade happens to come first.
        best_idx = max(labels, key=lambda idx: probabilities[idx])
        highest_class = labels[best_idx]
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    logger.info("Prediction complete: highest probability class = %s", highest_class)
    return {
        "detailed_classification": full_confidences,
        "highest_probability_class": highest_class,
    }
# Run the server directly (development entry point).
if __name__ == "__main__":
    from pathlib import Path

    # With reload=True uvicorn needs an import string; derive the module name
    # from this file instead of hard-coding "main", so the script still works
    # when saved under another name (e.g. app.py).
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=7860, reload=True)