Spaces:
Runtime error
Runtime error
import onnxruntime | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
# Map each model output index to its diabetic-retinopathy stage name.
# The position in this list is the class index: entry i names column i of the
# model's output vector (predict() reads probabilities[i] for each key i).
labels = dict(
    enumerate(
        [
            "No DR",
            "Mild",
            "Moderate",
            "Severe",
            "Proliferative DR",
        ]
    )
)
# Load the ONNX model once at import time; every prediction reuses this session.
# `providers` is passed explicitly because, since onnxruntime 1.9, builds with
# more than one execution provider enabled (e.g. the GPU package) raise a
# ValueError when it is omitted. Replace with ["CUDAExecutionProvider",
# "CPUExecutionProvider"] if GPU inference is wanted.
session = onnxruntime.InferenceSession(
    "model.onnx",
    providers=["CPUExecutionProvider"],
)
# Preprocess image to match model input
def transform_image(image):
    """Turn a PIL image into the normalized float32 batch fed to the model.

    Args:
        image: A PIL image in any mode (RGB, RGBA, grayscale, palette, ...).

    Returns:
        A ``float32`` array of shape ``(1, 3, 224, 224)``. The image is resized
        to 224x224, scaled to [0, 1] and normalized per channel with the
        mean/std below. It is then transposed from HWC to CHW, and a leading
        batch axis is added.
    """
    # Force exactly 3 channels first. An RGBA upload (e.g. PNG) gives a
    # 4-channel array and a grayscale image gives a 2-D array. Neither
    # broadcasts against the 3-element mean/std below. For images that are
    # already RGB this conversion leaves the pixel values unchanged.
    image = image.convert("RGB").resize((224, 224))
    img_array = np.array(image, dtype=np.float32) / 255.0
    # Per-channel (R, G, B) normalization statistics; presumably computed on
    # the training data -- TODO confirm they match the training pipeline.
    mean = np.array([0.5353, 0.3628, 0.2486], dtype=np.float32)
    std = np.array([0.2126, 0.1586, 0.1401], dtype=np.float32)
    img_array = (img_array - mean) / std
    img_array = np.transpose(img_array, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(img_array, axis=0).astype(np.float32)  # -> [1, C, H, W]
# Prediction function
def predict(input_img):
    """Classify a retina image into one of the diabetic-retinopathy stages.

    Args:
        input_img: The uploaded PIL image, or ``None`` when the user submits
            without choosing an image.

    Returns:
        A 2-tuple ``(top_class, confidences)``:
        - ``top_class`` is the label name with the highest probability.
        - ``confidences`` maps every label name in ``labels`` to its softmax
          probability as a Python float.

    Raises:
        gr.Error: If no image was provided. Gradio then shows a readable message
            instead of an AttributeError from the preprocessing step.
    """
    if input_img is None:
        raise gr.Error("Please upload a retina image.")

    input_tensor = transform_image(input_img)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    # [0] selects the single requested output, [0] again the only batch item.
    prediction = session.run([output_name], {input_name: input_tensor})[0][0]

    # Softmax over the raw outputs, assumed to be logits. TODO confirm the
    # exported model has no final softmax layer. Subtracting the max keeps
    # np.exp from overflowing without changing the result.
    exp_preds = np.exp(prediction - np.max(prediction))
    probabilities = exp_preds / exp_preds.sum()

    # Class-wise scores keyed by human-readable label name.
    confidences = {labels[i]: float(probabilities[i]) for i in labels}
    top_class = max(confidences, key=confidences.get)
    # NOTE: this is a 2-tuple, so the Gradio interface wired to this function
    # must declare two output components (e.g. a Textbox and a gr.Label).
    # A single gr.Label accepts only a string/number or a {label: confidence}
    # dict and rejects a tuple.
    return top_class, confidences
# Create the Gradio Interface.
# `predict` returns a (top_class, confidences) tuple, so two output components
# are declared in that same order:
# - a text box for the predicted stage;
# - a gr.Label for the per-class probabilities.
# With a lone gr.Label, Gradio would pass the whole tuple to the Label, which
# raises ValueError because it only accepts a string/number or a
# {label: confidence} dict.
dr_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Predicted stage"),
        gr.Label(num_top_classes=5, label="Class probabilities"),
    ],
    title="Diabetic Retinopathy Detection",
    description="Upload a retina image to detect the stage of Diabetic Retinopathy.",
    # Relative paths to example images bundled with the app.
    examples=[
        "sample/1.jpeg",
        "sample/2.jpeg",
        "sample/3.jpeg",
        "sample/4.jpeg",
    ],
    allow_flagging="never",
    analytics_enabled=False,
)
# Entry point.
# No port is hard-coded. When `server_port` is omitted, Gradio uses the
# GRADIO_SERVER_PORT environment variable, or else the first free port from
# 7860. Hugging Face Spaces serves Gradio apps on that port (7860 by default).
# Set GRADIO_SERVER_PORT=8080 to reproduce the old local behaviour.
# `share=True` is deliberately not passed: Spaces does not support it, and
# locally it would publish the app on a public gradio.live URL.
if __name__ == "__main__":
    dr_app.launch(server_name="0.0.0.0")