Ahmed-El-Sharkawy commited on
Commit
712bb48
·
verified ·
1 Parent(s): 5b38cce

Upload 2 files

Browse files
Files changed (2) hide show
  1. MRI/best_model.pth +3 -0
  2. app-ver-2.py +189 -0
MRI/best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54389b37590563b2d9c36901e8cd8ab5e8840460cca8cf792b4ab5a620775bfa
3
+ size 1833570
app-ver-2.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gradio_app.py
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as transforms
8
+ import numpy as np
9
+ import cv2
10
+
11
+ # --- Models ---
12
+ class EnhancedCNN_MRI(nn.Module):
13
+ def __init__(self):
14
+ super(EnhancedCNN_MRI, self).__init__()
15
+ self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
16
+ self.bn1 = nn.BatchNorm2d(32)
17
+ self.pool1 = nn.MaxPool2d(2)
18
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
19
+ self.bn2 = nn.BatchNorm2d(64)
20
+ self.pool2 = nn.MaxPool2d(2)
21
+ self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
22
+ self.bn3 = nn.BatchNorm2d(128)
23
+ self.pool3 = nn.MaxPool2d(2)
24
+ self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
25
+ self.bn4 = nn.BatchNorm2d(256)
26
+ self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
27
+ self.fc1 = nn.Linear(256, 256)
28
+ self.dropout = nn.Dropout(0.5)
29
+ self.fc2 = nn.Linear(256, 1)
30
+
31
+ def forward(self, x):
32
+ x = self.pool1(F.relu(self.bn1(self.conv1(x))))
33
+ x = self.pool2(F.relu(self.bn2(self.conv2(x))))
34
+ x = self.pool3(F.relu(self.bn3(self.conv3(x))))
35
+ x = self.global_pool(F.relu(self.bn4(self.conv4(x))))
36
+ x = torch.flatten(x, 1)
37
+ x = self.dropout(F.relu(self.fc1(x)))
38
+ return self.fc2(x)
39
+
40
+
41
+ class EnhancedCNN_CT(nn.Module):
42
+ def __init__(self):
43
+ super(EnhancedCNN_CT, self).__init__()
44
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
45
+ self.bn1 = nn.BatchNorm2d(32)
46
+ self.pool1 = nn.MaxPool2d(2)
47
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
48
+ self.bn2 = nn.BatchNorm2d(64)
49
+ self.pool2 = nn.MaxPool2d(2)
50
+ self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
51
+ self.bn3 = nn.BatchNorm2d(128)
52
+ self.pool3 = nn.MaxPool2d(2)
53
+ self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
54
+ self.bn4 = nn.BatchNorm2d(256)
55
+ self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
56
+ self.fc1 = nn.Linear(256, 256)
57
+ self.dropout = nn.Dropout(0.5)
58
+ self.fc2 = nn.Linear(256, 1)
59
+
60
+ def forward(self, x):
61
+ x = self.pool1(F.relu(self.bn1(self.conv1(x))))
62
+ x = self.pool2(F.relu(self.bn2(self.conv2(x))))
63
+ x = self.pool3(F.relu(self.bn3(self.conv3(x))))
64
+ x = self.global_pool(F.relu(self.bn4(self.conv4(x))))
65
+ x = torch.flatten(x, 1)
66
+ x = self.dropout(F.relu(self.fc1(x)))
67
+ return self.fc2(x)
68
+
69
+
70
+ class Sub_Class_CNNModel_CT(nn.Module):
71
+ def __init__(self, num_classes=2):
72
+ super(Sub_Class_CNNModel_CT, self).__init__()
73
+ self.features = nn.Sequential(
74
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
75
+ nn.BatchNorm2d(32),
76
+ nn.ReLU(),
77
+ nn.MaxPool2d(2),
78
+ nn.Dropout(0.25),
79
+
80
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
81
+ nn.BatchNorm2d(64),
82
+ nn.ReLU(),
83
+ nn.MaxPool2d(2),
84
+ nn.Dropout(0.25),
85
+
86
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
87
+ nn.BatchNorm2d(128),
88
+ nn.ReLU(),
89
+ nn.MaxPool2d(2),
90
+ nn.Dropout(0.25)
91
+ )
92
+ self.classifier = nn.Sequential(
93
+ nn.Flatten(),
94
+ nn.Linear(128 * 28 * 28, 512),
95
+ nn.BatchNorm1d(512),
96
+ nn.ReLU(),
97
+ nn.Dropout(0.5),
98
+ nn.Linear(512, num_classes)
99
+ )
100
+
101
+ def forward(self, x):
102
+ x = self.features(x)
103
+ x = self.classifier(x)
104
+ return torch.softmax(x, dim=1)
105
+
106
+
107
+ # --- Preprocessing ---
108
+ def preprocess_mri(img):
109
+ img = img.convert("L")
110
+ transform = transforms.Compose([
111
+ transforms.Resize((224, 224)),
112
+ transforms.ToTensor()
113
+ ])
114
+ return transform(img).unsqueeze(0)
115
+
116
+
117
+ def preprocess_ct(img):
118
+ img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
119
+ resized = cv2.resize(img_cv, (224, 224))
120
+ img_pil = Image.fromarray(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
121
+ transform = transforms.Compose([transforms.ToTensor()])
122
+ return transform(img_pil).unsqueeze(0)
123
+
124
+
125
+ def preprocess_sub_ct(img):
126
+ img = img.convert("RGB")
127
+ transform = transforms.Compose([
128
+ transforms.Resize((224, 224)),
129
+ transforms.ToTensor(),
130
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
131
+ ])
132
+ return transform(img).unsqueeze(0)
133
+
134
+
135
+ # --- Inference Functions ---
136
+ def classify_mri(image):
137
+ model = EnhancedCNN_MRI()
138
+ model.load_state_dict(torch.load('MRI/best_model.pth', map_location='cpu'))
139
+ model.eval()
140
+ tensor = preprocess_mri(image)
141
+ with torch.no_grad():
142
+ output = model(tensor)
143
+ pred = torch.sigmoid(output).item()
144
+ return ("Stroke", float(pred)) if pred >= 0.5 else ("Normal", 1 - float(pred))
145
+
146
+
147
+ def classify_ct(image):
148
+ model = EnhancedCNN_CT()
149
+ model.load_state_dict(torch.load('CT/best_model_CT.pth', map_location='cpu'))
150
+ model.eval()
151
+ tensor = preprocess_ct(image)
152
+ with torch.no_grad():
153
+ output = model(tensor)
154
+ pred = torch.sigmoid(output).item()
155
+
156
+ if pred < 0.5:
157
+ return ("Normal", 1 - float(pred))
158
+
159
+ sub_model = Sub_Class_CNNModel_CT()
160
+ sub_model.load_state_dict(torch.load('CT/cnn_model_sub_class.pth', map_location='cpu'))
161
+ sub_model.eval()
162
+ tensor_sub = preprocess_sub_ct(image)
163
+ with torch.no_grad():
164
+ sub_output = sub_model(tensor_sub)
165
+ sub_pred = torch.argmax(sub_output, dim=1).item()
166
+ sub_conf = sub_output[0][sub_pred].item()
167
+
168
+ sub_class_names = ['hemorrhagic', 'ischaemic']
169
+ return (f"Stroke - {sub_class_names[sub_pred]}", float(sub_conf))
170
+
171
+
172
+ # --- Gradio Interface ---
173
+ mri_ui = gr.Interface(
174
+ fn=classify_mri,
175
+ inputs=gr.Image(type="pil"),
176
+ outputs=[gr.Label(label="Prediction"), gr.Number(label="Confidence")],
177
+ title="🧠 MRI Stroke Classifier"
178
+ )
179
+
180
+ ct_ui = gr.Interface(
181
+ fn=classify_ct,
182
+ inputs=gr.Image(type="pil"),
183
+ outputs=[gr.Label(label="Prediction"), gr.Number(label="Confidence")],
184
+ title="🧠 CT Stroke + Subtype Classifier"
185
+ )
186
+
187
+ demo = gr.TabbedInterface([mri_ui, ct_ui], ["MRI Classifier", "CT Classifier"])
188
+
189
+ demo.launch()