Abhishek0323 commited on
Commit
543adc9
·
verified ·
1 Parent(s): 9c77632

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -7
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import streamlit as st
 
2
  import requests
3
  import pandas as pd
4
  import json
 
5
 
6
  # Function to create the payload for the model
7
  def create_tf_serving_json(data):
8
  return {'inputs': {name: data[name].tolist() for name in data.keys()} if isinstance(data, dict) else data.tolist()}
9
 
10
  # Function to send a request to the model endpoint
11
- def score_model(dataset):
12
- url = 'Add_Model_URL'
13
- token = 'Add_Databricks_Token'
14
  headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
15
  ds_dict = {'dataframe_split': dataset.to_dict(orient='split')} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
16
  data_json = json.dumps(ds_dict, allow_nan=True)
@@ -19,6 +19,18 @@ def score_model(dataset):
19
  raise Exception(f'Request failed with status {response.status_code}, {response.text}')
20
  return response.json()
21
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Streamlit app UI
23
  st.title('Employee Churn Prediction')
24
 
@@ -37,7 +49,7 @@ with st.form(key='churn_form'):
37
  salaryVec_0, salaryVec_1 = 0.0, 0.0
38
  if salary == 'Low':
39
  salaryVec_0 = 1.0
40
- elif salary == 'Medium':
41
  salaryVec_1 = 1.0
42
 
43
  # Submit button
@@ -58,10 +70,20 @@ if submit_button:
58
  'salaryVec_1': [salaryVec_1]
59
  })
60
 
61
- # Get prediction from the model
 
 
 
62
  try:
63
- prediction = score_model(input_data)
64
- churn_prediction = prediction['predictions'][0]
 
 
 
 
 
 
 
65
  if churn_prediction == 1:
66
  st.write('The employee is likely to churn.')
67
  else:
 
1
  import streamlit as st
2
+ import os
3
  import requests
4
  import pandas as pd
5
  import json
6
+ import joblib # Import joblib to load the local model
7
 
8
  # Function to create the payload for the model
9
  def create_tf_serving_json(data):
10
  return {'inputs': {name: data[name].tolist() for name in data.keys()} if isinstance(data, dict) else data.tolist()}
11
 
12
  # Function to send a request to the model endpoint
13
+ def score_model(dataset, url, token):
 
 
14
  headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
15
  ds_dict = {'dataframe_split': dataset.to_dict(orient='split')} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
16
  data_json = json.dumps(ds_dict, allow_nan=True)
 
19
  raise Exception(f'Request failed with status {response.status_code}, {response.text}')
20
  return response.json()
21
 
22
+ # Load local model function
23
+ def load_local_model():
24
+ return joblib.load('Random Forest - Grid Search_model.pkl')
25
+
26
+ # Function to make predictions with local model
27
+ def predict_with_local_model(model, input_data):
28
+ input_data = input_data[['satisfaction_level', 'last_evaluation', 'number_project',
29
+ 'average_montly_hours', 'time_spend_company', 'Work_accident',
30
+ 'promotion_last_5years', 'salaryVec_0', 'salaryVec_1']]
31
+ prediction = model.predict(input_data)
32
+ return prediction[0]
33
+
34
  # Streamlit app UI
35
  st.title('Employee Churn Prediction')
36
 
 
49
  salaryVec_0, salaryVec_1 = 0.0, 0.0
50
  if salary == 'Low':
51
  salaryVec_0 = 1.0
52
+ elif salary is 'Medium':
53
  salaryVec_1 = 1.0
54
 
55
  # Submit button
 
70
  'salaryVec_1': [salaryVec_1]
71
  })
72
 
73
+ # Check if URL and token are specified
74
+ url = 'Add_Model_URL'
75
+ token = 'Add_Databricks_Token'
76
+
77
  try:
78
+ if url != 'Add_Model_URL' and token != 'Add_Databricks_Token':
79
+ # Use the API if URL and token are specified
80
+ prediction = score_model(input_data, url, token)
81
+ churn_prediction = prediction['predictions'][0]
82
+ else:
83
+ # Use the local model if URL and token are not specified
84
+ local_model = load_local_model()
85
+ churn_prediction = predict_with_local_model(local_model, input_data)
86
+
87
  if churn_prediction == 1:
88
  st.write('The employee is likely to churn.')
89
  else: