Abhishek0323 commited on
Commit
bc742f3
·
verified ·
1 Parent(s): ac13b6c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import requests
4
+ import pandas as pd
5
+ import json
6
+
7
+ # Function to create the payload for the model
8
+ def create_tf_serving_json(data):
9
+ return {'inputs': {name: data[name].tolist() for name in data.keys()} if isinstance(data, dict) else data.tolist()}
10
+
11
+ # Function to send a request to the model endpoint
12
+ def score_model(dataset):
13
+ url = 'https://adb-3810412827421523.3.azuredatabricks.net/serving-endpoints/churnpredictionmodel0323/invocations'
14
+ headers = {'Authorization': f'Bearer {os.environ.get("dapi0c3745fa836d2634501e12bde7463bb1-2")}', 'Content-Type': 'application/json'}
15
+ ds_dict = {'dataframe_split': dataset.to_dict(orient='split')} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
16
+ data_json = json.dumps(ds_dict, allow_nan=True)
17
+ response = requests.request(method='POST', headers=headers, url=url, data=data_json)
18
+ if response.status_code != 200:
19
+ raise Exception(f'Request failed with status {response.status_code}, {response.text}')
20
+ return response.json()
21
+
22
+ # Streamlit app UI
23
+ st.title('Employee Churn Prediction')
24
+
25
+ # Create a form
26
+ with st.form(key='churn_form'):
27
+ satisfaction_level = st.slider('Satisfaction Level', 0.0, 1.0, 0.5)
28
+ last_evaluation = st.slider('Last Evaluation', 0.0, 1.0, 0.5)
29
+ number_project = st.slider('Number of Projects', 1, 10, 3)
30
+ average_montly_hours = st.slider('Average Monthly Hours', 50, 350, 200)
31
+ time_spend_company = st.slider('Time Spent in Company (years)', 1, 10, 3)
32
+ work_accident = st.selectbox('Work Accident', [0, 1])
33
+ promotion_last_5years = st.selectbox('Promotion in Last 5 Years', [0, 1])
34
+ salary = st.selectbox('Salary Level', ['Low', 'Medium', 'High'])
35
+
36
+ # Encode salary into one-hot vectors
37
+ salaryVec_0, salaryVec_1 = 0.0, 0.0
38
+ if salary == 'Low':
39
+ salaryVec_0 = 1.0
40
+ elif salary == 'Medium':
41
+ salaryVec_1 = 1.0
42
+
43
+ # Submit button
44
+ submit_button = st.form_submit_button(label='Predict Churn')
45
+
46
+ # Handle form submission
47
+ if submit_button:
48
+ # Create a DataFrame with the input data
49
+ input_data = pd.DataFrame({
50
+ 'satisfaction_level': [satisfaction_level],
51
+ 'last_evaluation': [last_evaluation],
52
+ 'number_project': [number_project],
53
+ 'average_montly_hours': [average_montly_hours],
54
+ 'time_spend_company': [time_spend_company],
55
+ 'Work_accident': [work_accident],
56
+ 'promotion_last_5years': [promotion_last_5years],
57
+ 'salaryVec_0': [salaryVec_0],
58
+ 'salaryVec_1': [salaryVec_1]
59
+ })
60
+
61
+ # Get prediction from the model
62
+ try:
63
+ prediction = score_model(input_data)
64
+ churn_prediction = prediction['predictions'][0]
65
+ if churn_prediction == 1:
66
+ st.write('The employee is likely to churn.')
67
+ else:
68
+ st.write('The employee is not likely to churn.')
69
+ except Exception as e:
70
+ st.error(f'Error: {e}')