Ahmed-El-Sharkawy commited on
Commit
28b8414
·
verified ·
1 Parent(s): f99c21b

Rename app.py to main.py

Browse files
Files changed (2) hide show
  1. app.py +0 -133
  2. main.py +143 -0
app.py DELETED
@@ -1,133 +0,0 @@
1
- from flask import Flask, request, jsonify
2
- import pandas as pd
3
- import numpy as np
4
- from data import StrokeData,HeartData
5
- from sklearn.preprocessing import LabelEncoder
6
- from sklearn.preprocessing import StandardScaler
7
- from sklearn.ensemble import RandomForestClassifier
8
-
9
- import joblib
10
- import pickle
11
-
12
-
13
- class HeartData:
14
- def __init__(self, age, sex, chest_pain_type, resting_bp, restecg, max_hr, exang, oldpeak, slope, thal):
15
- self.features = [age, sex, chest_pain_type, resting_bp, restecg, max_hr, exang, oldpeak, slope, thal]
16
-
17
- class StrokeData:
18
- def __init__(self, age, hypertension, heart_disease, ever_married, work_type, avg_glucose_level, bmi, smoking_status):
19
- self.features = [age, hypertension, heart_disease, ever_married, work_type, avg_glucose_level, bmi, smoking_status]
20
-
21
- class HealthPredictor:
22
- def __init__(self):
23
- self.heart_model_path = 'Heart_Disease/Saved_Model_Status/HeartModelRandomForest'
24
- self.heart_scaler_path = 'Heart_Disease/Saved_Model_Status/Standard_scaler.pkl'
25
- self.stroke_model_path = 'Stroke_Code/Saved_Model_Status/StrokeModelRandomForest'
26
- self.stroke_scaler_path = 'Stroke_Code/Saved_Model_Status/Standard_scaler.pkl'
27
- self.encoders_paths = {
28
- 'ever_married': 'Stroke_Code/Saved_Model_Status/ever_married_encoder.pkl',
29
- 'work_type': 'Stroke_Code/Saved_Model_Status/work_type_encoder.pkl',
30
- 'smoking_status': 'Stroke_Code/Saved_Model_Status/smoking_status_encoder.pkl'
31
- }
32
-
33
- self.heart_model = self.load_model(self.heart_model_path)
34
- self.heart_scaler = self.load_scaler(self.heart_scaler_path)
35
- self.stroke_model = self.load_model(self.stroke_model_path)
36
- self.stroke_scaler = self.load_scaler(self.stroke_scaler_path)
37
-
38
- self.ever_married_encoder = self.load_encoder(self.encoders_paths['ever_married'])
39
- self.work_type_encoder = self.load_encoder(self.encoders_paths['work_type'])
40
- self.smoking_status_encoder = self.load_encoder(self.encoders_paths['smoking_status'])
41
-
42
- def load_model(self, path):
43
- with open(path, 'rb') as file:
44
- return pickle.load(file)
45
-
46
- def load_scaler(self, path):
47
- return joblib.load(path)
48
-
49
- def load_encoder(self, path):
50
- return joblib.load(path)
51
-
52
- def predict_heart(self, data_point):
53
- data_point_scaled = self.heart_scaler.transform(np.array([data_point]))
54
- return self.heart_model.predict(data_point_scaled)[0]
55
-
56
- def predict_stroke(self, data_point):
57
- data_point[3] = self.ever_married_encoder.transform([data_point[3]])[0]
58
- data_point[4] = self.work_type_encoder.transform([data_point[4]])[0]
59
- data_point[7] = self.smoking_status_encoder.transform([data_point[7]])[0]
60
-
61
- data_point_scaled = self.stroke_scaler.transform(np.array([data_point]))
62
-
63
- # Get prediction probability
64
- probabilities = self.stroke_model.predict_proba(data_point_scaled)[0]
65
-
66
- # You can return both prediction and probabilities if needed
67
- prediction = np.argmax(probabilities)
68
- return prediction, probabilities
69
-
70
- # return self.stroke_model.predict(data_point_scaled)[0]
71
-
72
-
73
- class PersonData:
74
- def __init__(self, age, sex, chest_pain_type, resting_bp, restecg, max_hr, exang, oldpeak, slope, thal,
75
- hypertension, ever_married, work_type, avg_glucose_level, bmi, smoking_status):
76
- self.features = [age, sex, chest_pain_type, resting_bp, restecg, max_hr, exang, oldpeak, slope, thal,
77
- hypertension, ever_married, work_type, avg_glucose_level, bmi, smoking_status]
78
- self.predictor = HealthPredictor()
79
- self.heart_prediction = self.predict_heart()
80
- self.stroke_prediction, self.stroke_proba = self.predict_stroke()
81
-
82
- def predict_heart(self):
83
- heart_data = HeartData(*self.features[:10])
84
- return self.predictor.predict_heart(heart_data.features)
85
-
86
- def predict_stroke(self):
87
- self.heart_prediction
88
- stroke_data = StrokeData(self.features[0], self.features[10], self.heart_prediction, self.features[11], self.features[12],
89
- self.features[13], self.features[14], self.features[15])
90
- return self.predictor.predict_stroke(stroke_data.features)
91
-
92
- app = Flask(__name__)
93
-
94
- @app.route('/', methods=['GET'])
95
- def home():
96
- return "✅ Sahha Health Prediction API is Running", 200
97
-
98
- @app.route('/predict', methods=['POST'])
99
- def predict():
100
- try:
101
- data = request.get_json()
102
-
103
- person_data = PersonData(
104
- age=data['age'],
105
- sex=data['sex'],
106
- chest_pain_type=data['chest_pain_type'],
107
- resting_bp=data['resting_bp'],
108
- restecg=data['restecg'],
109
- max_hr=data['max_hr'],
110
- exang=data['exang'],
111
- oldpeak=data['oldpeak'],
112
- slope=data['slope'],
113
- thal=data['thal'],
114
- hypertension=data['hypertension'],
115
- ever_married=data['ever_married'],
116
- work_type=data['work_type'],
117
- avg_glucose_level=data['avg_glucose_level'],
118
- bmi=data['bmi'],
119
- smoking_status=data['smoking_status']
120
- )
121
-
122
- return jsonify({
123
- 'heart_prediction': int(person_data.heart_prediction),
124
- 'stroke_prediction': int(person_data.stroke_prediction),
125
- 'stroke_probability': round(float(person_data.stroke_proba[person_data.stroke_prediction]), 4)
126
- })
127
-
128
- except Exception as e:
129
- return jsonify({'error': str(e)}), 500
130
-
131
- if __name__ == "__main__":
132
- app.run(host='0.0.0.0', port=8087, debug=True)
133
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms as transforms
7
+ import numpy as np
8
+ import cv2
9
+
10
+ # --- Models ---
11
+ class EnhancedCNN_CT(nn.Module):
12
+ def __init__(self):
13
+ super(EnhancedCNN_CT, self).__init__()
14
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
15
+ self.bn1 = nn.BatchNorm2d(32)
16
+ self.pool1 = nn.MaxPool2d(2)
17
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
18
+ self.bn2 = nn.BatchNorm2d(64)
19
+ self.pool2 = nn.MaxPool2d(2)
20
+ self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
21
+ self.bn3 = nn.BatchNorm2d(128)
22
+ self.pool3 = nn.MaxPool2d(2)
23
+ self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
24
+ self.bn4 = nn.BatchNorm2d(256)
25
+ self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
26
+ self.fc1 = nn.Linear(256, 256)
27
+ self.dropout = nn.Dropout(0.5)
28
+ self.fc2 = nn.Linear(256, 1)
29
+
30
+ def forward(self, x):
31
+ x = self.pool1(F.relu(self.bn1(self.conv1(x))))
32
+ x = self.pool2(F.relu(self.bn2(self.conv2(x))))
33
+ x = self.pool3(F.relu(self.bn3(self.conv3(x))))
34
+ x = self.global_pool(F.relu(self.bn4(self.conv4(x))))
35
+ x = torch.flatten(x, 1)
36
+ x = self.dropout(F.relu(self.fc1(x)))
37
+ return self.fc2(x)
38
+
39
+ class Sub_Class_CNNModel_CT(nn.Module):
40
+ def __init__(self, num_classes=2):
41
+ super(Sub_Class_CNNModel_CT, self).__init__()
42
+ self.features = nn.Sequential(
43
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
44
+ nn.BatchNorm2d(32),
45
+ nn.ReLU(),
46
+ nn.MaxPool2d(2),
47
+ nn.Dropout(0.25),
48
+
49
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
50
+ nn.BatchNorm2d(64),
51
+ nn.ReLU(),
52
+ nn.MaxPool2d(2),
53
+ nn.Dropout(0.25),
54
+
55
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
56
+ nn.BatchNorm2d(128),
57
+ nn.ReLU(),
58
+ nn.MaxPool2d(2),
59
+ nn.Dropout(0.25)
60
+ )
61
+ self.classifier = nn.Sequential(
62
+ nn.Flatten(),
63
+ nn.Linear(128 * 28 * 28, 512),
64
+ nn.BatchNorm1d(512),
65
+ nn.ReLU(),
66
+ nn.Dropout(0.5),
67
+ nn.Linear(512, num_classes)
68
+ )
69
+
70
+ def forward(self, x):
71
+ x = self.features(x)
72
+ x = self.classifier(x)
73
+ return torch.softmax(x, dim=1)
74
+
75
+ def preprocess_ct(img):
76
+ img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
77
+ resized = cv2.resize(img_cv, (224, 224))
78
+ img_pil = Image.fromarray(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
79
+ transform = transforms.Compose([transforms.ToTensor()])
80
+ return transform(img_pil).unsqueeze(0)
81
+
82
+ def preprocess_sub_ct(img):
83
+ img = img.convert("RGB")
84
+ transform = transforms.Compose([
85
+ transforms.Resize((224, 224)),
86
+ transforms.ToTensor(),
87
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
88
+ ])
89
+ return transform(img).unsqueeze(0)
90
+
91
+ # --- Inference Functions ---
92
+ def classify_ct(image):
93
+ model = EnhancedCNN_CT()
94
+ model.load_state_dict(torch.load('CT/best_model_CT.pth', map_location='cpu'))
95
+ model.eval()
96
+ tensor = preprocess_ct(image)
97
+ with torch.no_grad():
98
+ output = model(tensor)
99
+ pred = torch.sigmoid(output).item()
100
+
101
+ if pred < 0.5:
102
+ return ("Normal", 1 - float(pred))
103
+
104
+ sub_model = Sub_Class_CNNModel_CT()
105
+ sub_model.load_state_dict(torch.load('CT/cnn_model_sub_class.pth', map_location='cpu'))
106
+ sub_model.eval()
107
+ tensor_sub = preprocess_sub_ct(image)
108
+ with torch.no_grad():
109
+ sub_output = sub_model(tensor_sub)
110
+ sub_pred = torch.argmax(sub_output, dim=1).item()
111
+ sub_conf = sub_output[0][sub_pred].item()
112
+
113
+ sub_class_names = ['hemorrhagic', 'ischaemic']
114
+ return (f"Stroke - {sub_class_names[sub_pred]}", float(sub_conf))
115
+
116
+ app = Flask(__name__)
117
+
118
+ @app.route('/', methods=['GET'])
119
+ def home():
120
+ return "✅ Sahha Health Prediction API is Running", 200
121
+
122
+ @app.route('/predict_computer_vision', methods=['POST'])
123
+ def predict_computer_vision():
124
+ try:
125
+ if 'image' not in request.files:
126
+ return jsonify({'error': 'No image provided'}), 400
127
+
128
+ file = request.files['image']
129
+ image = Image.open(file.stream)
130
+
131
+ result, confidence = classify_ct(image)
132
+
133
+ return jsonify({
134
+ 'main_prediction': result,
135
+ 'confidence': round(confidence, 4)
136
+ })
137
+
138
+ except Exception as e:
139
+ return jsonify({'error': str(e)}), 500
140
+
141
+ # if __name__ == "__main__":
142
+ # app.run(host='0.0.0.0', port=5000, debug=True)
143
+