Roberta2024 commited on
Commit
bef5c48
·
verified ·
1 Parent(s): 8355d75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -84
app.py CHANGED
@@ -1,85 +1,107 @@
1
- # Import necessary libraries
2
  import pandas as pd
3
- import seaborn as sns
4
- import matplotlib.pyplot as plt
5
- from sklearn.model_selection import train_test_split
6
- from sklearn.linear_model import LogisticRegression
7
- from sklearn.metrics import confusion_matrix, roc_curve, auc
8
- from sklearn.preprocessing import StandardScaler, LabelEncoder
9
- import joblib
10
- import os
11
-
12
- # File path
13
- file_path = "heart.csv"
14
-
15
- # Step 1: Data Cleaning and Encoding
16
- # Load data
17
- data = pd.read_csv(file_path)
18
-
19
- # Handle missing values (example: filling with median)
20
- data = data.fillna(data.median())
21
-
22
- # Encode categorical variables
23
- label_encoders = {}
24
- for column in data.select_dtypes(include=['object']).columns:
25
- le = LabelEncoder()
26
- data[column] = le.fit_transform(data[column])
27
- label_encoders[column] = le
28
-
29
- # Step 2: Plotting the Dependency Matrix
30
- plt.figure(figsize=(12, 8))
31
- correlation_matrix = data.corr()
32
- sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm')
33
- plt.title('Correlation Matrix')
34
- plt.show()
35
-
36
- # Step 3: Supervised Learning Model for Prediction
37
- # Define features and target
38
- X = data.drop('target', axis=1) # Assuming 'target' is the target variable
39
- y = data['target']
40
-
41
- # Split the data
42
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
43
-
44
- # Standardize the data
45
- scaler = StandardScaler()
46
- X_train = scaler.fit_transform(X_train)
47
- X_test = scaler.transform(X_test)
48
-
49
- # Train the model
50
- model = LogisticRegression()
51
- model.fit(X_train, y_train)
52
-
53
- # Make predictions
54
- y_pred = model.predict(X_test)
55
- y_pred_prob = model.predict_proba(X_test)[:, 1]
56
-
57
- # Step 4: Evaluation Using Confusion Matrix and Plotting ROC Curve
58
- # Confusion Matrix
59
- conf_matrix = confusion_matrix(y_test, y_pred)
60
- sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
61
- plt.title('Confusion Matrix')
62
- plt.xlabel('Predicted')
63
- plt.ylabel('Actual')
64
- plt.show()
65
-
66
- # ROC Curve
67
- fpr, tpr, _ = roc_curve(y_test, y_pred_prob)
68
- roc_auc = auc(fpr, tpr)
69
- plt.figure()
70
- plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:0.2f})')
71
- plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
72
- plt.xlim([0.0, 1.0])
73
- plt.ylim([0.0, 1.05])
74
- plt.xlabel('False Positive Rate')
75
- plt.ylabel('True Positive Rate')
76
- plt.title('Receiver Operating Characteristic (ROC) Curve')
77
- plt.legend(loc='lower right')
78
- plt.show()
79
-
80
- # Ensure the directory exists before saving the model
81
- model_directory = './models'
82
- os.makedirs(model_directory, exist_ok=True)
83
- model_filename = os.path.join(model_directory, 'logistic_regression_model.sav')
84
- joblib.dump(model, model_filename)
85
- print(f"Model saved to {model_filename}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import pandas as pd
3
+ import numpy as np
4
+ import joblib, os
5
+
6
+ script_dir = os.path.dirname(os.path.abspath(__file__))
7
+ pipeline_path = os.path.join(script_dir, 'toolkit', 'pipeline.joblib')
8
+ model_path = os.path.join(script_dir, 'toolkit', 'Random Forest Classifier.joblib')
9
+
10
+ # Load transformation pipeline and model
11
+ pipeline = joblib.load(pipeline_path)
12
+ model = joblib.load(model_path)
13
+
14
+ # Update predict function to handle new parameters
15
+ def predict(age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal):
16
+ # Create a dataframe with the input data
17
+ input_df = pd.DataFrame({
18
+ 'age': [age],
19
+ 'sex': [sex],
20
+ 'cp': [cp],
21
+ 'trestbps': [trestbps],
22
+ 'chol': [chol],
23
+ 'fbs': [fbs],
24
+ 'restecg': [restecg],
25
+ 'thalach': [thalach],
26
+ 'exang': [exang],
27
+ 'oldpeak': [oldpeak],
28
+ 'slope': [slope],
29
+ 'ca': [ca],
30
+ 'thal': [thal]
31
+ })
32
+
33
+ # Process input data using the pipeline
34
+ X_processed = pipeline.transform(input_df)
35
+
36
+ # Make predictions using the model
37
+ prediction_probs = model.predict_proba(X_processed)[0]
38
+ prediction_label = {
39
+ "Prediction: CHURN 🔴": prediction_probs[1],
40
+ "Prediction: STAY ✅": prediction_probs[0]
41
+ }
42
+
43
+ return prediction_label
44
+
45
+ input_interface = []
46
+
47
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
48
+
49
+ Title = gr.Label('Customer Churn Prediction App')
50
+
51
+ with gr.Row():
52
+ Title
53
+
54
+ with gr.Row():
55
+ gr.Markdown("This app predicts likelihood of a customer to leave or stay with the company")
56
+
57
+ with gr.Row():
58
+ with gr.Column():
59
+ input_interface_column_1 = [
60
+ gr.components.Slider(label='Age', minimum=0, maximum=120, step=1),
61
+ gr.components.Radio([0, 1], label='Sex'),
62
+ gr.components.Slider(label='Chest Pain Type', minimum=0, maximum=3, step=1),
63
+ gr.components.Slider(label='Resting Blood Pressure', minimum=0, maximum=200, step=1),
64
+ gr.components.Slider(label='Cholesterol', minimum=0, maximum=600, step=1),
65
+ gr.components.Radio([0, 1], label='Fasting Blood Sugar > 120 mg/dl')
66
+ ]
67
+
68
+ with gr.Column():
69
+ input_interface_column_2 = [
70
+ gr.components.Slider(label='Resting ECG', minimum=0, maximum=2, step=1),
71
+ gr.components.Slider(label='Max Heart Rate Achieved', minimum=60, maximum=220, step=1),
72
+ gr.components.Radio([0, 1], label='Exercise Induced Angina'),
73
+ gr.components.Slider(label='ST Depression Induced by Exercise', minimum=0.0, maximum=10.0, step=0.1),
74
+ gr.components.Slider(label='Slope of Peak Exercise ST Segment', minimum=0, maximum=2, step=1),
75
+ gr.components.Slider(label='Number of Major Vessels (0-3)', minimum=0, maximum=3, step=1),
76
+ gr.components.Slider(label='Thalassemia (0-3)', minimum=0, maximum=3, step=1)
77
+ ]
78
+
79
+ with gr.Row():
80
+ input_interface.extend(input_interface_column_1)
81
+ input_interface.extend(input_interface_column_2)
82
+
83
+ with gr.Row():
84
+ predict_btn = gr.Button('Predict')
85
+ output_interface = gr.Label(label="churn")
86
+
87
+ with gr.Accordion("Open for information on inputs", open=False):
88
+ gr.Markdown("""This app receives the following as inputs and processes them to return the prediction on whether a customer, will churn or not.
89
+
90
+ - age: Age of the customer
91
+ - sex: Sex of the customer (0: Female, 1: Male)
92
+ - cp: Chest Pain Type (0: typical angina, 1: atypical angina, 2: non-anginal pain, 3: asymptomatic)
93
+ - trestbps: Resting Blood Pressure (in mm Hg on admission to the hospital)
94
+ - chol: Serum Cholesterol in mg/dl
95
+ - fbs: Fasting Blood Sugar > 120 mg/dl (0: No, 1: Yes)
96
+ - restecg: Resting Electrocardiographic results (0: normal, 1: having ST-T wave abnormality, 2: showing probable or definite left ventricular hypertrophy)
97
+ - thalach: Maximum Heart Rate Achieved
98
+ - exang: Exercise Induced Angina (0: No, 1: Yes)
99
+ - oldpeak: ST depression induced by exercise relative to rest
100
+ - slope: The slope of the peak exercise ST segment
101
+ - ca: Number of major vessels (0-3) colored by fluoroscopy
102
+ - thal: Thalassemia (0: normal, 1: fixed defect, 2: reversible defect, 3: unknown)
103
+ """)
104
+
105
+ predict_btn.click(fn=predict, inputs=input_interface, outputs=output_interface)
106
+
107
+ app.launch(share=True)