# gradio_app.py
"""Gradio demo that classifies MRI and CT scans as stroke / normal.

The MRI tab runs a single binary classifier.  The CT tab runs a binary
classifier first and, when it predicts a stroke, a second model that
distinguishes the stroke subtype.
"""
import functools
from typing import Tuple, Type

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image


# --- Models ---
class _EnhancedCNN(nn.Module):
    """Four conv/batch-norm stages followed by a two-layer head emitting one logit.

    Shared by the MRI and CT binary classifiers, which differ only in the
    number of input channels.  Attribute names are kept exactly as in the
    original per-modality classes so existing checkpoints still load.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(256, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (pre-sigmoid) logit for each item in the batch."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.global_pool(F.relu(self.bn4(self.conv4(x))))
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


class EnhancedCNN_MRI(_EnhancedCNN):
    """Binary classifier for single-channel (grayscale) MRI input."""

    def __init__(self):
        super().__init__(in_channels=1)


class EnhancedCNN_CT(_EnhancedCNN):
    """Binary classifier for three-channel (RGB) CT input."""

    def __init__(self):
        super().__init__(in_channels=3)


class Sub_Class_CNNModel_CT(nn.Module):
    """CT subtype classifier that returns softmax probabilities.

    The classifier's first linear layer is sized for 128 x 28 x 28 features,
    i.e. a 224 x 224 input after three 2x max-pool stages.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 28 * 28, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class probabilities (softmax over dim 1)."""
        x = self.features(x)
        x = self.classifier(x)
        return torch.softmax(x, dim=1)


# --- Preprocessing ---
_INPUT_SIZE = (224, 224)

# Built once at import time instead of on every request.
_MRI_TRANSFORM = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
])
_CT_TRANSFORM = transforms.ToTensor()
_SUB_CT_TRANSFORM = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def preprocess_mri(img: Image.Image) -> torch.Tensor:
    """Convert *img* to grayscale, resize to 224x224, return a (1, 1, 224, 224) tensor."""
    return _MRI_TRANSFORM(img.convert("L")).unsqueeze(0)


def preprocess_ct(img: Image.Image) -> torch.Tensor:
    """Resize *img* to 224x224 with OpenCV and return a (1, 3, 224, 224) tensor.

    The image is forced to RGB first so that grayscale or palette inputs do
    not crash.  ``cv2.resize`` treats channels independently, so the
    RGB->BGR->RGB round trip the original code performed around it was a
    no-op and has been dropped.
    """
    rgb = np.array(img.convert("RGB"))
    resized = cv2.resize(rgb, _INPUT_SIZE)
    return _CT_TRANSFORM(resized).unsqueeze(0)


def preprocess_sub_ct(img: Image.Image) -> torch.Tensor:
    """Resize *img* to 224x224 RGB, normalize per channel, return a (1, 3, 224, 224) tensor."""
    return _SUB_CT_TRANSFORM(img.convert("RGB")).unsqueeze(0)


# --- Inference Functions ---
_SUB_CLASS_NAMES = ('hemorrhagic', 'ischaemic')


@functools.lru_cache(maxsize=None)
def _load_model(model_cls: Type[nn.Module], weights_path: str) -> nn.Module:
    """Instantiate *model_cls*, load weights from *weights_path*, switch to eval mode.

    Results are cached per (class, path) so the checkpoint is read from disk
    only once per process rather than on every request.
    """
    model = model_cls()
    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted source (consider weights_only=True where torch supports it).
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    model.eval()
    return model


def _require_image(image) -> None:
    """Raise a user-visible Gradio error when no image was submitted."""
    if image is None:
        raise gr.Error("Please upload an image before submitting.")


def classify_mri(image: Image.Image) -> Tuple[str, float]:
    """Classify an MRI image as "Stroke" or "Normal".

    Returns:
        ``(label, confidence)`` where confidence is the sigmoid probability of
        the predicted class.

    Raises:
        gr.Error: if *image* is ``None``.
    """
    _require_image(image)
    model = _load_model(EnhancedCNN_MRI, 'MRI/best_model.pth')
    with torch.no_grad():
        prob = torch.sigmoid(model(preprocess_mri(image))).item()
    if prob >= 0.5:
        return ("Stroke", prob)
    return ("Normal", 1 - prob)


def classify_ct(image: Image.Image) -> Tuple[str, float]:
    """Classify a CT image as normal or stroke, and name the stroke subtype.

    The binary model runs first.  Only when it predicts a stroke
    (probability >= 0.5) is the subtype model evaluated.

    Returns:
        ``("Normal", confidence)`` or ``("Stroke - <subtype>", confidence)``;
        for a stroke the confidence is the subtype model's softmax
        probability of the chosen subtype.

    Raises:
        gr.Error: if *image* is ``None``.
    """
    _require_image(image)
    model = _load_model(EnhancedCNN_CT, 'CT/best_model_CT.pth')
    with torch.no_grad():
        prob = torch.sigmoid(model(preprocess_ct(image))).item()
    if prob < 0.5:
        return ("Normal", 1 - prob)

    sub_model = _load_model(Sub_Class_CNNModel_CT, 'CT/cnn_model_sub_class.pth')
    with torch.no_grad():
        sub_probs = sub_model(preprocess_sub_ct(image))
    sub_pred = torch.argmax(sub_probs, dim=1).item()
    sub_conf = sub_probs[0][sub_pred].item()
    return (f"Stroke - {_SUB_CLASS_NAMES[sub_pred]}", sub_conf)


# --- Gradio Interface ---
mri_ui = gr.Interface(
    fn=classify_mri,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(label="Prediction"), gr.Number(label="Confidence")],
    title="🧠 MRI Stroke Classifier",
)

ct_ui = gr.Interface(
    fn=classify_ct,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(label="Prediction"), gr.Number(label="Confidence")],
    title="🧠 CT Stroke + Subtype Classifier",
)

demo = gr.TabbedInterface([mri_ui, ct_ui], ["MRI Classifier", "CT Classifier"])

# Launch only when run as a script so importing this module does not start
# a web server.
if __name__ == "__main__":
    demo.launch()