Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,85 +1,107 @@
|
|
1 |
-
|
2 |
import pandas as pd
|
3 |
-
import
|
4 |
-
import
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
#
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
#
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import joblib, os
|
5 |
+
|
6 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
7 |
+
pipeline_path = os.path.join(script_dir, 'toolkit', 'pipeline.joblib')
|
8 |
+
model_path = os.path.join(script_dir, 'toolkit', 'Random Forest Classifier.joblib')
|
9 |
+
|
10 |
+
# Load transformation pipeline and model
|
11 |
+
pipeline = joblib.load(pipeline_path)
|
12 |
+
model = joblib.load(model_path)
|
13 |
+
|
14 |
+
# Update predict function to handle new parameters
|
15 |
+
def predict(age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal):
|
16 |
+
# Create a dataframe with the input data
|
17 |
+
input_df = pd.DataFrame({
|
18 |
+
'age': [age],
|
19 |
+
'sex': [sex],
|
20 |
+
'cp': [cp],
|
21 |
+
'trestbps': [trestbps],
|
22 |
+
'chol': [chol],
|
23 |
+
'fbs': [fbs],
|
24 |
+
'restecg': [restecg],
|
25 |
+
'thalach': [thalach],
|
26 |
+
'exang': [exang],
|
27 |
+
'oldpeak': [oldpeak],
|
28 |
+
'slope': [slope],
|
29 |
+
'ca': [ca],
|
30 |
+
'thal': [thal]
|
31 |
+
})
|
32 |
+
|
33 |
+
# Process input data using the pipeline
|
34 |
+
X_processed = pipeline.transform(input_df)
|
35 |
+
|
36 |
+
# Make predictions using the model
|
37 |
+
prediction_probs = model.predict_proba(X_processed)[0]
|
38 |
+
prediction_label = {
|
39 |
+
"Prediction: CHURN 🔴": prediction_probs[1],
|
40 |
+
"Prediction: STAY ✅": prediction_probs[0]
|
41 |
+
}
|
42 |
+
|
43 |
+
return prediction_label
|
44 |
+
|
45 |
+
input_interface = []
|
46 |
+
|
47 |
+
with gr.Blocks(theme=gr.themes.Soft()) as app:
|
48 |
+
|
49 |
+
Title = gr.Label('Customer Churn Prediction App')
|
50 |
+
|
51 |
+
with gr.Row():
|
52 |
+
Title
|
53 |
+
|
54 |
+
with gr.Row():
|
55 |
+
gr.Markdown("This app predicts likelihood of a customer to leave or stay with the company")
|
56 |
+
|
57 |
+
with gr.Row():
|
58 |
+
with gr.Column():
|
59 |
+
input_interface_column_1 = [
|
60 |
+
gr.components.Slider(label='Age', minimum=0, maximum=120, step=1),
|
61 |
+
gr.components.Radio([0, 1], label='Sex'),
|
62 |
+
gr.components.Slider(label='Chest Pain Type', minimum=0, maximum=3, step=1),
|
63 |
+
gr.components.Slider(label='Resting Blood Pressure', minimum=0, maximum=200, step=1),
|
64 |
+
gr.components.Slider(label='Cholesterol', minimum=0, maximum=600, step=1),
|
65 |
+
gr.components.Radio([0, 1], label='Fasting Blood Sugar > 120 mg/dl')
|
66 |
+
]
|
67 |
+
|
68 |
+
with gr.Column():
|
69 |
+
input_interface_column_2 = [
|
70 |
+
gr.components.Slider(label='Resting ECG', minimum=0, maximum=2, step=1),
|
71 |
+
gr.components.Slider(label='Max Heart Rate Achieved', minimum=60, maximum=220, step=1),
|
72 |
+
gr.components.Radio([0, 1], label='Exercise Induced Angina'),
|
73 |
+
gr.components.Slider(label='ST Depression Induced by Exercise', minimum=0.0, maximum=10.0, step=0.1),
|
74 |
+
gr.components.Slider(label='Slope of Peak Exercise ST Segment', minimum=0, maximum=2, step=1),
|
75 |
+
gr.components.Slider(label='Number of Major Vessels (0-3)', minimum=0, maximum=3, step=1),
|
76 |
+
gr.components.Slider(label='Thalassemia (0-3)', minimum=0, maximum=3, step=1)
|
77 |
+
]
|
78 |
+
|
79 |
+
with gr.Row():
|
80 |
+
input_interface.extend(input_interface_column_1)
|
81 |
+
input_interface.extend(input_interface_column_2)
|
82 |
+
|
83 |
+
with gr.Row():
|
84 |
+
predict_btn = gr.Button('Predict')
|
85 |
+
output_interface = gr.Label(label="churn")
|
86 |
+
|
87 |
+
with gr.Accordion("Open for information on inputs", open=False):
|
88 |
+
gr.Markdown("""This app receives the following as inputs and processes them to return the prediction on whether a customer, will churn or not.
|
89 |
+
|
90 |
+
- age: Age of the customer
|
91 |
+
- sex: Sex of the customer (0: Female, 1: Male)
|
92 |
+
- cp: Chest Pain Type (0: typical angina, 1: atypical angina, 2: non-anginal pain, 3: asymptomatic)
|
93 |
+
- trestbps: Resting Blood Pressure (in mm Hg on admission to the hospital)
|
94 |
+
- chol: Serum Cholesterol in mg/dl
|
95 |
+
- fbs: Fasting Blood Sugar > 120 mg/dl (0: No, 1: Yes)
|
96 |
+
- restecg: Resting Electrocardiographic results (0: normal, 1: having ST-T wave abnormality, 2: showing probable or definite left ventricular hypertrophy)
|
97 |
+
- thalach: Maximum Heart Rate Achieved
|
98 |
+
- exang: Exercise Induced Angina (0: No, 1: Yes)
|
99 |
+
- oldpeak: ST depression induced by exercise relative to rest
|
100 |
+
- slope: The slope of the peak exercise ST segment
|
101 |
+
- ca: Number of major vessels (0-3) colored by fluoroscopy
|
102 |
+
- thal: Thalassemia (0: normal, 1: fixed defect, 2: reversible defect, 3: unknown)
|
103 |
+
""")
|
104 |
+
|
105 |
+
predict_btn.click(fn=predict, inputs=input_interface, outputs=output_interface)
|
106 |
+
|
107 |
+
app.launch(share=True)
|