harishvijayasarangan committed on
Commit
8229bc6
·
verified ·
1 Parent(s): 70cc854

Upload 4 files

Browse files
Files changed (4) hide show
  1. .dockerfile +18 -0
  2. main.py +113 -0
  3. model.onnx +3 -0
  4. requirements.txt +6 -0
.dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.12-slim

# Set working directory
WORKDIR /app

# Copy only the dependency manifest first so the pip-install layer is
# cached and not invalidated every time the code or model changes.
COPY requirements.txt .

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the model and application code after dependencies are installed.
COPY model.onnx .
COPY main.py .

# Expose port 7860 (required for Hugging Face Spaces)
EXPOSE 7860

# Run the FastAPI app
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import io
import logging
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import onnxruntime
import numpy as np
from PIL import Image
import uvicorn

# Root logging configuration: timestamped records for all loggers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("dr-api")

app = FastAPI(
    title="Diabetic Retinopathy Detection API",
    description="API for detecting diabetic retinopathy from retinal images",
    version="1.0.0"
)

# CORS: wide open for development.
# NOTE(review): browsers reject credentialed responses when
# allow_origins=["*"] is combined with allow_credentials=True — restrict
# origins to the actual frontend URL in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # frontend URL in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mapping from the model's output class indices to human-readable DR grades.
labels = {
    0: "No DR",
    1: "Mild",
    2: "Moderate",
    3: "Severe",
    4: "Proliferative DR",
}

# Load the ONNX model once at import time. On failure, `session` is left
# as None so the endpoints can report "unhealthy" / 503 instead of the
# whole app crashing at startup.
try:
    logger.info("Loading ONNX model...")
    session = onnxruntime.InferenceSession('model.onnx')
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error(f"Error loading model: {e}")
    session = None
40
@app.get("/health")
async def health_check():
    """Report whether the service is up and the ONNX model was loaded."""
    if session is not None:
        return {"status": "healthy", "model_loaded": True}
    return {"status": "unhealthy", "message": "Model failed to load"}
45
def transform_image(image):
    """Turn a PIL RGB image into a normalized NCHW float32 tensor.

    Resizes to 224x224, scales pixel values to [0, 1], standardizes each
    channel with the dataset mean/std, moves channels first, and adds a
    leading batch axis of size 1.
    """
    resized = image.resize((224, 224))
    pixels = np.array(resized, dtype=np.float32) / 255.0
    channel_mean = np.array([0.5353, 0.3628, 0.2486], dtype=np.float32)
    channel_std = np.array([0.2126, 0.1586, 0.1401], dtype=np.float32)
    normalized = (pixels - channel_mean) / channel_std
    chw = normalized.transpose(2, 0, 1)
    return chw[np.newaxis, ...].astype(np.float32)
54
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Predict diabetic retinopathy from retinal image

    - **file**: Upload a retinal image file

    Returns detailed classification for all DR grades and a binary classification
    """
    logger.info(f"Received image: {file.filename}, content-type: {file.content_type}")
    if session is None:
        raise HTTPException(status_code=503, detail="Model not available")
    # content_type is None when the client omits the header; guard before
    # startswith so a missing header yields the intended 400, not an
    # AttributeError -> 500.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image")

    try:
        image_data = await file.read()
        input_img = Image.open(io.BytesIO(image_data)).convert("RGB")
        input_tensor = transform_image(input_img)
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        logger.info("Running inference")
        prediction = session.run([output_name], {input_name: input_tensor})[0][0]

        # Numerically stable softmax over the raw logits.
        exp_preds = np.exp(prediction - np.max(prediction))
        probabilities = exp_preds / exp_preds.sum()

        # Per-grade confidence as a whole-number percentage (float).
        full_confidences = {
            labels[i]: float(round(probabilities[i] * 100)) for i in labels
        }

        # First grade with the maximum rounded confidence wins ties.
        highest_class = max(full_confidences.items(), key=lambda x: x[1])[0]
        logger.info(f"Prediction complete: highest probability class = {highest_class}")

        return {
            "detailed_classification": full_confidences,
            "highest_probability_class": highest_class,
        }

    except Exception as e:
        logger.error(f"Error processing image: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
110
+
111
# Run the server
# Dev entry point only: the container starts uvicorn via the Dockerfile CMD,
# so this branch (with auto-reload enabled) runs only via `python main.py`.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2df12f77e5a9240ad729d61a4e63c0304cb232bc99f47c4a166c2330efa0780
3
+ size 28227960
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ onnxruntime
4
+ numpy
5
+ pillow
6
+ python-multipart