# gradio_app.py
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import cv2
# --- Models ---
class EnhancedCNN_MRI(nn.Module):
    """Binary stroke classifier for single-channel (grayscale) MRI slices."""
    def __init__(self):
        super(EnhancedCNN_MRI, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(256, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.global_pool(F.relu(self.bn4(self.conv4(x))))
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)  # raw logit; sigmoid is applied at inference time
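# Illustrative shape check (assumption: kept as a comment so it does not run at import time):
# a 1x1x224x224 grayscale batch should map to a single logit of shape (1, 1).
# with torch.no_grad():
#     assert EnhancedCNN_MRI()(torch.zeros(1, 1, 224, 224)).shape == (1, 1)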
class EnhancedCNN_CT(nn.Module):
    """Binary stroke classifier for three-channel (RGB) CT slices; same backbone as the MRI model."""
    def __init__(self):
        super(EnhancedCNN_CT, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(256, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.global_pool(F.relu(self.bn4(self.conv4(x))))
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)  # raw logit; sigmoid is applied at inference time
class Sub_Class_CNNModel_CT(nn.Module):
    """Stroke-subtype classifier (hemorrhagic vs. ischaemic) for CT slices."""
    def __init__(self, num_classes=2):
        super(Sub_Class_CNNModel_CT, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 128 * 28 * 28 assumes 224x224 inputs downsampled by three 2x2 max-pools (224 -> 112 -> 56 -> 28)
            nn.Linear(128 * 28 * 28, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return torch.softmax(x, dim=1)  # class probabilities
# --- Preprocessing ---
def preprocess_mri(img):
    # Grayscale, resize to 224x224, scale to [0, 1], add a batch dimension.
    img = img.convert("L")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    return transform(img).unsqueeze(0)

def preprocess_ct(img):
    # Resize with OpenCV to 224x224, convert back to PIL, scale to [0, 1].
    img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    resized = cv2.resize(img_cv, (224, 224))
    img_pil = Image.fromarray(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
    transform = transforms.Compose([transforms.ToTensor()])
    return transform(img_pil).unsqueeze(0)

def preprocess_sub_ct(img):
    # RGB, resize to 224x224, scale to [0, 1], normalize with ImageNet statistics.
    img = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return transform(img).unsqueeze(0)
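# Model weight files loaded below (paths taken from the torch.load calls; the
# directory layout is assumed from those paths, not verified here):
#   MRI/best_model.pth          - EnhancedCNN_MRI weights (stroke vs. normal, MRI)
#   CT/best_model_CT.pth        - EnhancedCNN_CT weights (stroke vs. normal, CT)
#   CT/cnn_model_sub_class.pth  - Sub_Class_CNNModel_CT weights (hemorrhagic vs. ischaemic)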
# --- Inference Functions ---
def classify_mri(image):
    # Binary MRI classifier: sigmoid(logit) >= 0.5 -> "Stroke", otherwise "Normal".
    model = EnhancedCNN_MRI()
    model.load_state_dict(torch.load('MRI/best_model.pth', map_location='cpu'))
    model.eval()
    tensor = preprocess_mri(image)
    with torch.no_grad():
        output = model(tensor)
        pred = torch.sigmoid(output).item()
    return ("Stroke", float(pred)) if pred >= 0.5 else ("Normal", 1 - float(pred))

def classify_ct(image):
    # Stage 1: binary CT classifier. If the scan is classified as normal, stop here.
    model = EnhancedCNN_CT()
    model.load_state_dict(torch.load('CT/best_model_CT.pth', map_location='cpu'))
    model.eval()
    tensor = preprocess_ct(image)
    with torch.no_grad():
        output = model(tensor)
        pred = torch.sigmoid(output).item()
    if pred < 0.5:
        return ("Normal", 1 - float(pred))
    # Stage 2: subtype classifier distinguishes hemorrhagic from ischaemic stroke.
    sub_model = Sub_Class_CNNModel_CT()
    sub_model.load_state_dict(torch.load('CT/cnn_model_sub_class.pth', map_location='cpu'))
    sub_model.eval()
    tensor_sub = preprocess_sub_ct(image)
    with torch.no_grad():
        sub_output = sub_model(tensor_sub)
        sub_pred = torch.argmax(sub_output, dim=1).item()
        sub_conf = sub_output[0][sub_pred].item()
    sub_class_names = ['hemorrhagic', 'ischaemic']
    return (f"Stroke - {sub_class_names[sub_pred]}", float(sub_conf))
# --- Gradio Interface ---
mri_ui = gr.Interface(
    fn=classify_mri,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(label="Prediction"), gr.Number(label="Confidence")],
    title="🧠 MRI Stroke Classifier"
)

ct_ui = gr.Interface(
    fn=classify_ct,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(label="Prediction"), gr.Number(label="Confidence")],
    title="🧠 CT Stroke + Subtype Classifier"
)

demo = gr.TabbedInterface([mri_ui, ct_ui], ["MRI Classifier", "CT Classifier"])
demo.launch()
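# Note: on Hugging Face Spaces the plain launch() above is sufficient; when running the
# script locally and a temporary public URL is wanted, Gradio also accepts
# demo.launch(share=True).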