# Last change: commit ced0511 by Ahmed-El-Sharkawy (verified) — "Rename app-ver-2.py to app.py".
# app.py — Gradio app for MRI and CT stroke classification.
import functools

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
# --- Models ---
class EnhancedCNN_MRI(nn.Module):
    """Binary stroke classifier for single-channel MRI images.

    Architecture: four conv(3x3, padding 1) -> BatchNorm -> ReLU stages with
    32, 64, 128 and 256 channels. The first three stages are followed by 2x2
    max-pooling and the last by global average pooling. A 256 -> 256 -> 1 MLP
    head with dropout comes after.

    Input is ``(batch, 1, H, W)`` (``conv1`` takes one channel). The output is
    one raw logit per sample, shape ``(batch, 1)``; ``classify_mri`` applies
    ``sigmoid`` to it.

    The attribute names below are the keys of the saved ``state_dict``
    checkpoint, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Creation order fixes the order of random weight initialisation;
        # keep it stable.
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        # Global pooling makes the head independent of the input resolution.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(256, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        """Return the stroke logit for each image in ``x``."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
            (self.conv4, self.bn4, self.global_pool),
        )
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))
        features = torch.flatten(x, 1)
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)
class EnhancedCNN_CT(nn.Module):
    """Binary stroke classifier for 3-channel CT images.

    The layer stack is the same as ``EnhancedCNN_MRI`` except that ``conv1``
    takes three input channels:
    - four conv(3x3, padding 1) -> BatchNorm -> ReLU stages (32/64/128/256
      channels);
    - 2x2 max-pooling after the first three stages and global average pooling
      after the fourth;
    - a 256 -> 256 -> 1 MLP head with dropout.

    Input is ``(batch, 3, H, W)``. The output is one raw logit per sample,
    shape ``(batch, 1)``; ``classify_ct`` applies ``sigmoid`` to it.

    The attribute names below are the keys of the saved ``state_dict``
    checkpoint, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Creation order fixes the order of random weight initialisation;
        # keep it stable.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        # Global pooling makes the head independent of the input resolution.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(256, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        """Return the stroke logit for each image in ``x``."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
            (self.conv4, self.bn4, self.global_pool),
        )
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))
        features = torch.flatten(x, 1)
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)
class Sub_Class_CNNModel_CT(nn.Module):
    """Stroke-subtype classifier for 3-channel CT images.

    ``features`` holds three blocks, each conv(3x3, padding 1) -> BatchNorm ->
    ReLU -> 2x2 max-pool -> Dropout(0.25), with 32, 64 and 128 channels.
    ``classifier`` flattens the result and applies a 512-unit hidden layer
    before the final ``num_classes`` outputs.

    The flatten size ``128 * 28 * 28`` assumes a 224x224 input (three 2x
    poolings: 224 / 8 = 28), which is what ``preprocess_sub_ct`` produces.

    ``forward`` returns softmax probabilities, not logits.
    NOTE(review): if this class is also used for training with
    ``CrossEntropyLoss``, softmax would be applied twice — confirm.

    The ``nn.Sequential`` layout (layer order and indices) defines the
    ``state_dict`` keys of the saved checkpoint and must not change.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        feature_layers = []
        for in_ch, out_ch in ((3, 32), (32, 64), (64, 128)):
            feature_layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Dropout(0.25),
            ]
        self.features = nn.Sequential(*feature_layers)

        flat_dim = 128 * 28 * 28
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return per-class probabilities (rows sum to 1) for each image."""
        logits = self.classifier(self.features(x))
        return torch.softmax(logits, dim=1)
# --- Preprocessing ---
def preprocess_mri(img):
    """Convert a PIL image into a ``(1, 1, 224, 224)`` tensor for the MRI model.

    The image is converted to single-channel grayscale ("L") and resized to
    224x224. ``ToTensor`` then scales it to [0, 1]. No mean/std normalisation
    is applied.
    """
    pipeline = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    grayscale = img.convert("L")
    # Add the leading batch dimension the model expects.
    return pipeline(grayscale).unsqueeze(0)
def preprocess_ct(img):
    """Convert a PIL image into a ``(1, 3, 224, 224)`` tensor for the CT model.

    Steps:
    - force 3-channel RGB;
    - resize to 224x224 with ``cv2.resize`` (OpenCV's default bilinear
      interpolation);
    - scale to [0, 1] with ``ToTensor``.

    Unlike ``preprocess_sub_ct``, no mean/std normalisation is applied.

    Args:
        img: PIL image of any mode. Grayscale ("L"), palette ("P") and RGBA
            images are accepted; any alpha channel is discarded.

    Returns:
        Float tensor of shape ``(1, 3, 224, 224)``.
    """
    # Force RGB first. A grayscale or palette image becomes a 2-D array,
    # which OpenCV's 3-channel colour conversion would reject.
    rgb = np.array(img.convert("RGB"))
    # cv2.resize is kept (instead of PIL's resize) so the interpolated pixel
    # values are unchanged. Resizing works per channel, so the array can stay
    # in RGB order; no BGR conversion is needed around it.
    resized = cv2.resize(rgb, (224, 224))
    transform = transforms.Compose([transforms.ToTensor()])
    return transform(Image.fromarray(resized)).unsqueeze(0)
def preprocess_sub_ct(img):
    """Convert a PIL image into a normalised ``(1, 3, 224, 224)`` tensor for the subtype model.

    The image is converted to RGB and resized to 224x224, which matches the
    ``128 * 28 * 28`` flatten size of ``Sub_Class_CNNModel_CT``. It is then
    scaled to [0, 1] and normalised per channel. The mean/std values are the
    widely used ImageNet channel statistics.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )
    rgb = img.convert("RGB")
    # Add the leading batch dimension the model expects.
    return pipeline(rgb).unsqueeze(0)
# --- Inference Functions ---
@functools.lru_cache(maxsize=None)
def _load_mri_model():
    """Build the MRI model, load ``MRI/best_model.pth`` on CPU, and cache it.

    The cache means the checkpoint is read from disk once per process instead
    of on every request. A failed load is not cached and is retried.
    """
    model = EnhancedCNN_MRI()
    model.load_state_dict(torch.load('MRI/best_model.pth', map_location='cpu'))
    model.eval()
    return model


def classify_mri(image):
    """Classify an MRI image as ``"Stroke"`` or ``"Normal"``.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input. It may
            be ``None`` if the user submits without uploading.

    Returns:
        ``(label, confidence)``. ``confidence`` is the sigmoid probability of
        the returned label, so it is always >= 0.5.

    Raises:
        gr.Error: if no image was provided. Gradio shows the message to the
            user.
    """
    if image is None:
        raise gr.Error("Please upload an MRI image.")
    model = _load_mri_model()
    tensor = preprocess_mri(image)
    with torch.no_grad():
        stroke_prob = torch.sigmoid(model(tensor)).item()
    if stroke_prob >= 0.5:
        return ("Stroke", float(stroke_prob))
    return ("Normal", 1 - float(stroke_prob))
# Subtype label for each output index of Sub_Class_CNNModel_CT, as used by
# classify_ct below.
_CT_SUBTYPE_NAMES = ('hemorrhagic', 'ischaemic')


@functools.lru_cache(maxsize=None)
def _load_ct_model():
    """Build the CT stroke/normal model, load ``CT/best_model_CT.pth`` on CPU, and cache it."""
    model = EnhancedCNN_CT()
    model.load_state_dict(torch.load('CT/best_model_CT.pth', map_location='cpu'))
    model.eval()
    return model


@functools.lru_cache(maxsize=None)
def _load_ct_subtype_model():
    """Build the CT subtype model, load ``CT/cnn_model_sub_class.pth`` on CPU, and cache it.

    This model has a ``128 * 28 * 28 -> 512`` linear layer, so it is large.
    Caching avoids re-reading it from disk for every stroke-positive request.
    """
    model = Sub_Class_CNNModel_CT()
    model.load_state_dict(torch.load('CT/cnn_model_sub_class.pth', map_location='cpu'))
    model.eval()
    return model


def classify_ct(image):
    """Classify a CT image in two stages: stroke vs. normal, then stroke subtype.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input. It may
            be ``None`` if the user submits without uploading.

    Returns:
        ``("Normal", p)`` if the first-stage sigmoid probability is below
        0.5, where ``p`` is the probability of "Normal".

        Otherwise ``("Stroke - <subtype>", q)``, where ``q`` is the subtype
        model's softmax probability for the chosen subtype. It is not the
        first-stage stroke probability.

    Raises:
        gr.Error: if no image was provided. Gradio shows the message to the
            user.
    """
    if image is None:
        raise gr.Error("Please upload a CT image.")

    with torch.no_grad():
        stroke_prob = torch.sigmoid(_load_ct_model()(preprocess_ct(image))).item()
    if stroke_prob < 0.5:
        return ("Normal", 1 - float(stroke_prob))

    # The subtype model uses its own (normalised) preprocessing and already
    # returns probabilities from forward().
    with torch.no_grad():
        sub_probs = _load_ct_subtype_model()(preprocess_sub_ct(image))
    sub_pred = torch.argmax(sub_probs, dim=1).item()
    sub_conf = sub_probs[0][sub_pred].item()
    return (f"Stroke - {_CT_SUBTYPE_NAMES[sub_pred]}", float(sub_conf))
# --- Gradio Interface ---
def _build_interface(classifier, title):
    """Wrap ``classifier`` in a single-image Gradio Interface.

    Each call creates fresh input/output components, so every tab gets its
    own instances. The outputs are a "Prediction" label and a "Confidence"
    number, matching the ``(label, confidence)`` tuples the classifiers
    return.
    """
    return gr.Interface(
        fn=classifier,
        inputs=gr.Image(type="pil"),
        outputs=[gr.Label(label="Prediction"), gr.Number(label="Confidence")],
        title=title,
    )


# One tab per imaging modality.
mri_ui = _build_interface(classify_mri, "🧠 MRI Stroke Classifier")
ct_ui = _build_interface(classify_ct, "🧠 CT Stroke + Subtype Classifier")
demo = gr.TabbedInterface([mri_ui, ct_ui], ["MRI Classifier", "CT Classifier"])
demo.launch()